diff --git a/README.md b/README.md index 515ccf5..94fc1b9 100644 --- a/README.md +++ b/README.md @@ -265,6 +265,12 @@ print(messages) # {'Numbers': [('float', 'a'), ('float', 'b')], 'Result': [('float', 'value')]} ``` +## Training + +With the SDK, you can also train models and use them when calling the service. + + + --- ###### 1 This method uses a call to a paid smart contract function. diff --git a/docs/main/account.md b/docs/main/account.md index c951c99..d9202fd 100644 --- a/docs/main/account.md +++ b/docs/main/account.md @@ -26,7 +26,7 @@ is extended by: - #### description -`TransactionError` is a custom exception class that is raised when an Ethereum transaction receipt has a status of 0. +`TransactionError` is a custom exception class that is raised when an Ethereum transaction receipt has a status of 0. This indicates that the transaction failed. Can provide a custom message. Optionally includes receipt #### attributes diff --git a/docs/main/client_lib_generator.md b/docs/main/client_lib_generator.md index bf45715..31fe37c 100644 --- a/docs/main/client_lib_generator.md +++ b/docs/main/client_lib_generator.md @@ -6,6 +6,10 @@ Entities: 1. [ClientLibGenerator](#class-clientlibgenerator) - [\_\_init\_\_](#__init__) - [generate_client_library](#generate_client_library) + - [generate_directories_by_params](#generate_directories_by_params) + - [create_service_client_libraries_path](#create_service_client_libraries_path) + - [receive_proto_files](#receive_proto_files) + - [training_added](#training_added) ### Class `ClientLibGenerator` @@ -36,6 +40,7 @@ Initializes a new instance of the class. Initializes the attributes by arguments - `metadata_provider` (StorageProvider): An instance of the `StorageProvider` class. - `org_id` (str): The organization ID of the service. - `service_id` (str): The service ID. +- `protodir` (Path | None): The directory where the .proto files are located. Default is _None_. ###### returns: @@ -51,3 +56,42 @@ Generates client library stub files based on specified organization and service ###### returns: - _None_ + +#### `generate_directories_by_params` + +Generates directories for client library in the `~/.snet` directory based on organization and +service ids using the `create_service_client_libraries_path` method. + +###### returns: + +- _None_ + +#### `create_service_client_libraries_path` + +Creates a directory for client library in the `~/.snet` directory based on organization and +service ids. + +###### returns: + +- _None_ + +#### `receive_proto_files` + +Receives .proto files from IPFS or FileCoin based on service metadata and extracts them to a +given directory. + +###### returns: + +- _None_ + +###### raises: + +- Exception: if the directory for storing proto files is not found. + +#### `training_added` + +Checks whether training is used in the service .proto file. + +###### returns: + +- _True_ if training is used in the service .proto file, _False_ otherwise. \ No newline at end of file diff --git a/docs/main/init.md b/docs/main/init.md index d2c9f0a..7a82f64 100644 --- a/docs/main/init.md +++ b/docs/main/init.md @@ -7,7 +7,6 @@ Entities: - [\_\_init\_\_](#__init__) - [create_service_client](#create_service_client) - [get_service_stub](#get_service_stub) - - [get_path_to_pb_files](#get_path_to_pb_files) - [get_module_by_keyword](#get_module_by_keyword) - [get_service_metadata](#get_service_metadata) - [_get_first_group](#_get_first_group) @@ -62,7 +61,7 @@ contract. Instantiates the Account object with the specified Web3 client, SDK co If `force_update` is True or if there are no gRPC stubs for the given service, the proto files are loaded and compiled using the `generate_client_library()` method of the `ClientLibGenerator` class instance. -It then initializes `payment_channel_management_strategy` to `DefaultPaymentStrategy` if it is not specified. +It then initializes `payment_strategy` to `DefaultPaymentStrategy` if it is not specified. It also sets the `options` dictionary with some default values. If `self._metadata_provider` is not specified it is initialized by `IPFSMetadataProvider`. It also gets the service stub using the `self.get_service_stub` method and the pb2 module using the `self.get_module_by_keyword` method. Finally, it creates a new instance @@ -73,7 +72,7 @@ of the `ServiceClient` class with all the required parameters, which is then ret - `org_id` (str): The ID of the organization. - `service_id` (str): The ID of the service. - `group_name` (str): The name of the payment group. Defaults to _None_. -- `payment_channel_management_strategy` (PaymentStrategy): The payment channel management strategy. Defaults to _None_. +- `payment_strategy` (PaymentStrategy): The payment channel management strategy. Defaults to _None_. - `free_call_auth_token_bin` (str): The free call authentication token in binary format. Defaults to _None_. - `free_call_token_expiry_block` (int): The block number when the free call token expires. Defaults to _None_. - `options` (dict): Additional options for the service client. Defaults to _None_. @@ -100,19 +99,6 @@ Retrieves the gRPC service stub for the given organization and service ID. - Exception: If an error occurs while importing a module. -#### `get_path_to_pb_files` - -Returns the path to the directory containing the protobuf files for a given organization and service. - -###### args: - -- `org_id` (str): The ID of the organization. -- `service_id` (str): The ID of the service. - -###### returns: - -- The path to the directory containing the protobuf files. (str) - #### `get_module_by_keyword` Retrieves the module name from the given organization ID, service ID, and keyword. diff --git a/docs/main/service_client.md b/docs/main/service_client.md index 4b44fdb..f3f23ab 100644 --- a/docs/main/service_client.md +++ b/docs/main/service_client.md @@ -6,10 +6,10 @@ Entities: 1. [ServiceClient](#class-serviceclient) - [\_\_init\_\_](#__init__) - [call_rpc](#call_rpc) + - [_get_service_stub](#_get_service_stub) - [_generate_grpc_stub](#_generate_grpc_stub) - [get_grpc_base_channel](#get_grpc_base_channel) - [_get_grpc_channel](#_get_grpc_channel) - - [_get_service_call_metadata](#_get_service_call_metadata) - [_filter_existing_channels_from_new_payment_channels](#_filter_existing_channels_from_new_payment_channels) - [load_open_channels](#load_open_channels) - [get_current_block_number](#get_current_block_number) @@ -23,10 +23,11 @@ Entities: - [generate_training_signature](#generate_training_signature) - [get_free_call_config](#get_free_call_config) - [get_service_details](#get_service_details) + - [training](#training) + - [_get_training_model_id](#_get_training_model_id) - [get_concurrency_flag](#get_concurrency_flag) - [get_concurrency_token_and_channel](#get_concurrency_token_and_channel) - [set_concurrency_token_and_channel](#set_concurrency_token_and_channel) - - [get_path_to_pb_files](#get_path_to_pb_files) - [get_services_and_messages_info](#get_services_and_messages_info) - [get_services_and_messages_info_as_pretty_string](#get_services_and_messages_info_as_pretty_string) @@ -58,7 +59,7 @@ the `PaymentStrategy` inheritor classes. - `payment_channel_provider` (PaymentChannelProvider): An instance of the `PaymentChannelProvider` class for working with channels and interacting with MPE. - `payment_channel_state_service_client` (Any): Stub for interacting with PaymentChannelStateService via gRPC. -- `service` (Any): The gRPC service stub instance. +- `service_stubs` (Any): The gRPC service stubs. - `pb2_module` (ModuleType): The imported protobuf module. - `payment_channels` (list[PaymentChannel]): The list of payment channels. - `last_read_block` (int): The last read block number. @@ -66,6 +67,8 @@ working with channels and interacting with MPE. SingularityNetToken contracts. - `sdk_web3` (Web3): The `Web3` instance. - `mpe_address` (str): The MPE contract address. +- `path_to_pb_files` (Path): The path to the protobuf files. +- `__training` (Training): An instance of the `Training` class. #### methods @@ -79,7 +82,7 @@ Initializes a new instance of the class. - `service_id` (str): The ID of the service. - `service_metadata` (MPEServiceMetadata): The metadata for the service. - `group` (dict): The payment group from the service metadata. -- `service_stub` (ServiceStub): The gRPC service stub. +- `service_stubs` (list[ServiceStub]): The gRPC service stubs. - `payment_strategy` (PaymentStrategy): The payment channel management strategy. - `options` (dict): Additional options for the service client. - `mpe_contract` (MPEContract): The MPE contract instance. @@ -87,6 +90,8 @@ Initializes a new instance of the class. - `sdk_web3` (Web3): The `Web3` instance. - `pb2_module` (str | ModuleType): The module containing the gRPC message definitions. - `payment_channel_provider` (PaymentChannelProvider): The payment channel provider instance. +- `path_to_pb_files` (Path): The path to the protobuf files. +- `training_added` (bool): Whether training enabled on the service or not. ###### returns: @@ -107,6 +112,18 @@ that are passed to the called method as arguments. - The response from the RPC method call. (Any) +#### `_get_service_stub` + +Generates a gRPC stub instance for all the service stubs and returns one which matches the rpc name. + +###### args: + +- `rpc_name` (str): The name of the RPC method to call. + +###### returns: + +- service_stub (Any): The gRPC service stub. + #### `_generate_grpc_stub` Generates a gRPC stub instance for the given service stub. @@ -148,16 +165,6 @@ a ValueError is raised with an error message. - ValueError: If the scheme in the service metadata is neither "http" nor "https". -#### `_get_service_call_metadata` - -Retrieves the metadata required for making a service call using the payment strategy. - -###### returns: - -- Payment metadata. (list[tuple[str, Any]]) - - - #### `_filter_existing_channels_from_new_payment_channels` Filters the new channel list so that only those that are not yet among the existing ones remain, @@ -300,6 +307,30 @@ Retrieves the details of the service. - A tuple containing the organization ID, service ID, group ID, and the first endpoint for the group. (tuple[str, str, str, str]) +#### `training` + +Property that returns the training object associated with the service. + +###### returns: + +- The training object associated with the service. (Training) + +###### raises: + +- NoTrainingException: If training is not implemented for the service. + +#### `_get_training_model_id` + +Converts model ID from `str` to stub object. + +###### args: + +- `model_id` (str): The model ID to convert. + +###### returns: + +- The stub object for the model ID. (Any) + #### `get_concurrency_flag` Returns the value of the `concurrency` option from the `self.options` dict. @@ -330,19 +361,6 @@ Sets the concurrency token and channel for the payment strategy. - _None_ -#### `get_path_to_pb_files` - -Returns the path to the directory containing the protobuf files for a given organization and service. - -###### args: - -- `org_id` (str): The ID of the organization. -- `service_id` (str): The ID of the service. - -###### returns: - -- The path to the directory containing the protobuf files. (str) - #### `get_services_and_messages_info` Retrieves information about services and messages defined in a protobuf file. diff --git a/docs/payment_strategies/freecall_payment_strategy.md b/docs/payment_strategies/freecall_payment_strategy.md index a563660..67424ac 100644 --- a/docs/payment_strategies/freecall_payment_strategy.md +++ b/docs/payment_strategies/freecall_payment_strategy.md @@ -42,7 +42,7 @@ _Note_: If any exception occurs during the process, it returns False. #### `get_payment_metadata` -Retrieves the payment metadata for a service client with the field `snet-paument-type` equals to `free-call` +Retrieves the payment metadata for a service client with the field `snet-payment-type` equals to `free-call` using the provided free call configuration. ###### args: diff --git a/docs/payment_strategies/paidcall_payment_strategy.md b/docs/payment_strategies/paidcall_payment_strategy.md index 8dff3f9..b7b48ab 100644 --- a/docs/payment_strategies/paidcall_payment_strategy.md +++ b/docs/payment_strategies/paidcall_payment_strategy.md @@ -15,7 +15,7 @@ Entities: extends: `PaymentStrategy` -is extended by: - +is extended by: `TrainingPaymentStrategy` #### description @@ -57,7 +57,7 @@ Returns the price of the service call using service client. #### `get_payment_metadata` -Creates and returns the payment metadata for a service client with the field `snet-paument-type` equals to `escrow`. +Creates and returns the payment metadata for a service client with the field `snet-payment-type` equals to `escrow`. ###### args: diff --git a/docs/payment_strategies/prepaid_payment_strategy.md b/docs/payment_strategies/prepaid_payment_strategy.md index 7e5094f..3f005ba 100644 --- a/docs/payment_strategies/prepaid_payment_strategy.md +++ b/docs/payment_strategies/prepaid_payment_strategy.md @@ -61,7 +61,7 @@ Returns the price of the service calls using service client. #### `get_payment_metadata` -Creates and returns the payment metadata for a service client with the field `snet-paument-type` equals +Creates and returns the payment metadata for a service client with the field `snet-payment-type` equals to `prepaid-call`. ###### args: diff --git a/docs/payment_strategies/training_payment_strategy.md b/docs/payment_strategies/training_payment_strategy.md new file mode 100644 index 0000000..b8942fb --- /dev/null +++ b/docs/payment_strategies/training_payment_strategy.md @@ -0,0 +1,95 @@ +## module: sdk.payment_strategies.training_payment_strategy + +[Link](https://github.com/singnet/snet-sdk-python/blob/master/snet/sdk/payment_strategies/training_payment_strategy.py) to GitHub + +Entities: +1. [TrainingPaymentStrategy](#class-trainingpaymentstrategy) + - [\_\_init\_\_](#__init__) + - [get_price](#get_price) + - [set_price](#set_price) + - [get_model_id](#get_model_id) + - [set_model_id](#set_model_id) + - [get_payment_metadata](#get_payment_metadata) + +### Class `TrainingPaymentStrategy` + +extends: `PaidCallPaymentStrategy` + +is extended by: - + +#### description + +The `TrainingPaymentStrategy` class extends `PaidCallPaymentStrategy` class. The difference from the +parent class is that for training, the call price is a dynamic value, and can be set using the appropriate +setter before the call. + +#### attributes + +- `_call_price` (int): The call price. Defaults to -1 (means that no price has been set). +- `_train_model_id` (str): The training model id. Defaults to empty string. + +#### methods + +#### `__init__` + +Initializes a new instance of the class. + +###### returns: + +- _None_ + +#### `get_price` + +Returns the price of the service call - `_call_price` value. + +###### returns: + +- The price of the service call. (int) + +###### raises: + +- `Exception`: If no price has been set. + +#### `set_price` + +Sets the price of the service call. + +###### args: + +- `call_price` (int): The price of the service call. + +###### returns: + +- _None_ + +#### `get_model_id` + +Returns the training model id - `_train_model_id` value. + +###### returns: + +- The training model id. (str) + +#### `set_model_id` + +Sets the training model id. + +###### args: + +- `model_id` (str): The training model id. + +###### returns: + +- _None_ + +#### `get_payment_metadata` + +Creates and returns the payment metadata for a service client with the field `snet-payment-type` equals to `train-call`. + +###### args: + +- `service_client` (ServiceClient): The service client object. + +###### returns: + +- The payment metadata. (list[tuple[str, str]]) diff --git a/docs/snet-sdk-python-documentation.md b/docs/snet-sdk-python-documentation.md index 841495c..844108d 100644 --- a/docs/snet-sdk-python-documentation.md +++ b/docs/snet-sdk-python-documentation.md @@ -32,11 +32,15 @@ A getting started guide for the SNET SDK for Python is available [here](https:// 3. [freecall_payment_strategy](payment_strategies/freecall_payment_strategy.md) 4. [paidcall_payment_strategy](payment_strategies/paidcall_payment_strategy.md) 5. [prepaid_payment_strategy](payment_strategies/prepaid_payment_strategy.md) + 6. [training_payment_strategy](payment_strategies/training_payment_strategy.md) 10. utils 1. [utils](utils/utils.md) 2. [ipfs_utils](utils/ipfs_utils.md) + 3. [call_utils](utils/call_utils.md) 11. training 1. [training](training/training.md) + 2. [responses](training/responses.md) + 3. [exceptions](training/exceptions.md) diff --git a/docs/training/exceptions.md b/docs/training/exceptions.md new file mode 100644 index 0000000..07799fb --- /dev/null +++ b/docs/training/exceptions.md @@ -0,0 +1,56 @@ +## module: sdk.training.exceptions + +[Link](https://github.com/singnet/snet-sdk-python/blob/master/snet/sdk/training/exceptions.py) to GitHub + +Entities: +1. [WrongDatasetException](#class-wrongdatasetexception) +2. [WrongMethodException](#class-wrongmethodexception) +3. [NoTrainingException](#class-notrainingexception) +4. [GRPCException](#class-grpcexception) +5. [NoSuchModelException](#class-nosuchmodelexception) + +### Class `WrongDatasetException` + +extends: `Exception` + +is extended by: - + +#### description + +This exception can be thrown when `_check_dataset` method of the `Training` class is called. +If the dataset is not compliant, this exception will be thrown, detailing what incompatibilities the dataset has. + +### Class `WrongMethodException` + +extends: `Exception` + +is extended by: - + +#### description + +This exception is thrown when you try to call `Training` methods with the wrong grpc method name. + +### Class `NoTrainingException` + +This exception is thrown while calling the `training` field of the `ServiceClient` class and the training is not +enabled in the service. + +### Class `GRPCException` + +extends: `RpcError` + +is extended by: - + +#### description + +This exception is thrown when there is an error in the grpc call. + +### Class `NoSuchModelException` + +extends: `Exception` + +is extended by: - + +#### description + +This exception is thrown when you try to call `Training` methods with the non-existent model id. diff --git a/docs/training/responses.md b/docs/training/responses.md new file mode 100644 index 0000000..9b90694 --- /dev/null +++ b/docs/training/responses.md @@ -0,0 +1,136 @@ +## module: sdk.training.responses + +[Link](https://github.com/singnet/snet-sdk-python/blob/master/snet/sdk/training/responses.py) to GitHub + +Entities: +1. [ModelMethodMessage](#class-modelmethodmessage) +2. [ModelStatus](#class-modelstatus) +3. [Model](#class-model) +4. [TrainingMetadata](#class-trainingmetadata) +5. [MethodMetadata](#class-methodmetadata) +6. [to_string](#function-to_string) + +### Class `ModelMethodMessage` + +extends: `Enum` + +is extended by: - + +#### description + +This is an `enum` that represents the available methods that can be called in the training grpc service. +It is used in the authorization messages. + +#### members + +- `CreateModel` +- `ValidateModelPrice` +- `TrainModelPrice` +- `DeleteModel` +- `GetTrainingMetadata` +- `GetAllModels` +- `GetModel` +- `UpdateModel` +- `GetMethodMetadata` +- `UploadAndValidate` +- `ValidateModel` +- `TrainModel` + +### Class `ModelStatus` + +extends: `Enum` + +is extended by: - + +#### description + +This is an `enum` that represents the status of a model. It is used to convert status object in the grpc call +response to a readable object. + +#### members + +- `CREATED` +- `VALIDATING` +- `VALIDATED` +- `TRAINING` +- `READY_TO_USE` +- `ERRORED` +- `DELETED` + +### Class `Model` + +extends: - + +is extended by: - + +#### description + +It is a data class that represents a model. It is used to convert model object in the grpc call response to a +readable object. + +#### attributes + +- `model_id` (str): The id of the model. +- `status` (ModelStatus): The status of the model. +- `created_date` (str): The date when the model was created. +- `updated_date` (str): The date when the model was updated. +- `name` (str): The name of the model. +- `description` (str): The description of the model. +- `grpc_method_name` (str): The name of the gRPC method for which the model was created. +- `grpc_service_name` (str): The name of the gRPC service. +- `address_list` (list[str]): A list of addresses with the access of the model. +- `is_public` (bool): Whether the model is publicly accessible. +- `training_data_link` (str): The link to the training data (not used in SDK). +- `created_by_address` (str): The address of the user wallet who created the model. +- `updated_by_address` (str): The address of the user wallet who updated the model. + +### Class `TrainingMetadata` + +extends: - + +is extended by: - + +#### description + +It is a data class that represents the training metadata. It is used to convert training metadata object in the +grpc call response to a readable object. + +#### attributes + +- `training_enabled` (bool): Whether training is enabled on the service. +- `training_in_proto` (bool): Whether training is in proto format. +- `training_methods` (dict[str, list[str]]): Dictionary of the following form: rpc service name - list of rpc methods. + +### Class `MethodMetadata` + +extends: - + +is extended by: - + +#### description + +It is a data class that represents the method metadata. It is used to convert method metadata object in the +grpc call response to a readable object. + +#### attributes + +- `default_model_id` (str): The default model id. +- `max_models_per_user` (int): The maximum number of models per user. +- `dataset_max_size_mb` (int): The maximum size of the dataset in MB. +- `dataset_max_count_files` (int): The maximum number of files in the dataset. +- `dataset_max_size_single_file_mb` (int): The maximum size of a single file in the dataset in MB. +- `dataset_files_type` (str): Allowed types of files in the dataset. (example: "jpg,png,mp3") +- `dataset_type` (str): The type of the dataset. (example: "zip,tar") +- `dataset_description` (str): Additional free-form requirements. + +### Function `to_string` + +Converts a data object to a string where each attribute is on a separate line with a name. + +#### args: + +- `obj` (Any): The data object to convert to a string. + +#### returns: + +- The string representation of the data object. (str) diff --git a/docs/training/training.md b/docs/training/training.md index 0f946c7..229b9cd 100644 --- a/docs/training/training.md +++ b/docs/training/training.md @@ -3,164 +3,357 @@ [Link](https://github.com/singnet/snet-sdk-python/blob/master/snet/sdk/training/training.py) to GitHub Entities: -1. [ModelMethodMessage](#class-modelmethodmessage) -2. [TrainingModel](#class-trainingmodel) - - [\_\_init\_\_](#__init__) - - [_invoke_model](#_invoke_model) +1. [Training](#class-training) + - [\__init__](#__init__) + - [get_model_id_object](#get_model_id_object) - [create_model](#create_model) - - [get_model_status](#get_model_status) + - [validate_model_price](#validate_model_price) + - [train_model_price](#train_model_price) - [delete_model](#delete_model) - - [update_model_access](#update_model_access) + - [get_training_metadata](#get_training_metadata) - [get_all_models](#get_all_models) + - [get_model](#get_model) + - [get_method_metadata](#get_method_metadata) + - [update_model](#update_model) + - [upload_and_validate](#upload_and_validate) + - [train_model](#train_model) + - [_call_method](#_call_method) + - [_get_training_stub](#_get_training_stub) + - [_get_auth_details](#_get_auth_details) + - [_check_method_name](#_check_method_name) + - [_check_training](#_check_training) + - [_check_dataset](#_check_dataset) + - [_get_grpc_channel](#_get_grpc_channel) + +### Class `Training` -### Class `ModelMethodMessage` - -extends: `Enum` +extends: - is extended by: - #### description -This is an `enum` that represents the available methods that can be called in the training grpc service. +`Training` is a class that is responsible for the training functionality of the service and interacting with it -#### members +#### attributes -- `CreateModel` (str): The method to create a new model. -- `GetModelStatus` (str): The method to get the status of a model. -- `UpdateModelAccess` (str): The method to update the access of a model. -- `DeleteModel` (str): The method to delete a model. -- `GetAllModels` (str): The method to get all models. +- `training_daemon` (ModuleType): The module containing the training daemon stubs. +- `training_daemon_grpc` (ModuleType): The module containing the training daemon gRPC stubs. +- `training` (ModuleType): The module containing the training stubs. +- `service_client` (ServiceClient): The `ServiceClient` instance. +- `is_enabled` (bool): Whether the training is enabled. +- `payment_strategy` (TrainingPaymentStrategy): The payment strategy used for training. -### Class `TrainingModel` +#### methods -extends: - +#### `__init__` -is extended by: - +Initializes the `Training` object. Imports the necessary modules and initializes the `TrainingPaymentStrategy` +object. Sets the `is_enabled` attribute using the `_check_training` method. -#### description +###### args: -This is a class that represents a training gRPC service. +- `service_client` (ServiceClient): The `ServiceClient` instance. +- `training_added` (bool): Whether the training is added in service proto file. Defaults to _False_. -#### attributes +###### returns: -- `training_pb2` (ModuleType): The gRPC service module. -- `training_pb2_grpc` (ModuleType): The gRPC service module. +- _None_ -#### methods +#### `get_model_id_object` -#### `__init__` +Converts the model ID from a string to an object from stub. + +###### args: -Initializes a new instance of the class. Imports gRPC service modules. +- `model_id` (str): The model ID to convert. ###### returns: -- _None_ +- The stub object for the model ID. (Any) -#### `_invoke_model` +#### `create_model` -Invokes the model by establishing a gRPC channel and generating an authorization request. +Creates a necessary request object `NewModelRequest(AuthorizationDetails, NewModel)` from stub and calls the +`create_model` grpc method using the `_call_method` method. ###### args: -- `service_client` (ServiceClient): The client object for the service. -- `msg` (ModelMethodMessage): The message containing the method to be invoked. +- `method_name` (str): The name of the service method for which we want to create a new model. +- `model_name` (str): The name of the model. +- `model_description` (str): The description of the model. Defaults to empty string. +- `is_public_accessible` (bool): Whether the model is publicly accessible. Defaults to _False_. +- `addresses_with_access` (list[str]): A list of addresses with access to the model. Defaults to empty list. ###### returns: -- A tuple containing the authorization request and the gRPC channel. (tuple[AuthorizationDetails, grpc.Channel]) +- The newly created model. (Model) + +#### `validate_model_price` + +Creates a necessary request object `AuthValidateRequest(AuthorizationDetails, model_id, training_data_link)` from +stub and calls the `validate_model_price` grpc method using the `_call_method` method. + +###### args: + +- `model_id` (str): The model ID to validate. + +###### returns: + +- Price of validating the model. (int) ###### raises: -- `ValueError`: If the scheme in the service metadata is not supported. +- `NoSuchModelException`: If the model with the specified ID does not exist. +- `GRPCException`: If the gRPC call fails. -#### `create_model` +#### `train_model_price` -Calls the `create_model` method in the gRPC training service stub to create a new model. +Creates a necessary request object `CommonRequest(AuthorizationDetails, model_id)` from stub and calls the +`train_model_price` grpc method using the `_call_method` method. ###### args: -- `service_client` (ServiceClient): The client object for the service. -- `grpc_method_name` (str): The name of the gRPC method to be invoked. -- `model_name` (str): The name of the model to be created. -- `description` (str): A description of the model. Defaults to ''. -- `training_data_link` (str): A link to the training data. Defaults to ''. -- `grpc_service_name` (str): The name of the gRPC service. Defaults to 'service'. -- `is_publicly_accessible` (bool): Whether the model is publicly accessible. Defaults to False. -- `address_list` (list[str]): A list of addresses. Defaults to None. +- `model_id` (str): The model ID to train. ###### returns: -- The response from the create model request. (Any) +- Price of training the model. (int) + +###### raises: -_Note_: Returns an exception if an error occurs during the create model request. +- `NoSuchModelException`: If the model with the specified ID does not exist. +- `GRPCException`: If the gRPC call fails. -#### `get_model_status` +#### `delete_model` -Calls the `get_model_status` method in the gRPC training service stub to get a model status. +Creates a necessary request object `CommonRequest(AuthorizationDetails, model_id)` from stub and calls the +`delete_model` grpc method using the `_call_method` method. ###### args: -- `service_client` (ServiceClient): The client object for the service. -- `model_id` (str): The ID of the model whose status to be retrieved. +- `model_id` (str): The model ID to delete. ###### returns: -- The response from the get model status request. (Any) +- Status of the model. (ModelStatus) + +###### raises: -_Note_: Returns an exception if an error occurs during the get model status request. +- `NoSuchModelException`: If the model with the specified ID does not exist. +- `GRPCException`: If the gRPC call fails. -#### `delete_model` +#### `get_training_metadata` -Calls the `delete_model` method in the gRPC training service stub to delete a model. +Calls the `get_training_metadata` grpc method using the `_call_method` method with empty request. + +###### returns: + +- Information about the training on the service. (TrainingMetadata) + +#### `get_all_models` + +Creates a necessary request object `AllModelsRequest` from stub and calls the +`get_all_models` grpc method using the `_call_method` method. Use arguments as the filters to get all models. ###### args: -- `service_client` (ServiceClient): The client object for the service. -- `model_id` (str): The ID of the model to be deleted. -- `grpc_service_name` (str): The name of the gRPC service. Defaults to 'service'. -- `grpc_method_name` (str): The name of the gRPC method to be invoked. +- `statuses` (list[ModelStatus]): Statuses by which models need to be filtered. Defaults to _None_ (no filter). +- `is_public` (bool): Whether the models are public or not. Defaults to _None_ (no filter). +- `grpc_method_name` (str): The name of the gRPC method to call. Defaults to empty string (no filter). +- `grpc_service_name` (str): The name of the gRPC service to call. Defaults to empty string (no filter). +- `model_name` (str): The name of the model. Defaults to empty string (no filter). +- `created_by_address` (str): The address of the user who created the model. Defaults to empty string (no filter). ###### returns: -- The response from the delete model request. (Any) +- List of models. (list[Model]) + +#### `get_model` -_Note_: Returns an exception if an error occurs during the delete model request. +Creates a necessary request object `CommonRequest(AuthorizationDetails, model_id)` from stub and calls the +`get_model` grpc method using the `_call_method` method. -#### `update_model_access` +###### args: -Calls the `update_model_access` method in the gRPC training service stub to update the access of a model. +- `model_id` (str): The model ID to get. + +###### returns: + +- Price of training the model. (int) + +###### raises: + +- `NoSuchModelException`: If the model with the specified ID does not exist. +- `GRPCException`: If the gRPC call fails. + +#### `get_method_metadata` + +Creates a necessary request object `MethodMetadataRequest(grpc_method_name, grpc_service_name, model_id)` from stub +and calls the `get_method_metadata` grpc method using the `_call_method` method. You can get metadata by `method_name` +or `model_id`. ###### args: -- `service_client` (ServiceClient): The client object for the service. -- `model_id` (str): The ID of the model whose access to be updated. -- `grpc_method_name` (str): The name of the gRPC method to be invoked. -- `model_name` (str): The name of the model. -- `is_punlic` (bool): Whether the model is publicly accessible. -- `description` (str): A description of the model. -- `grpc_service_name` (str): The name of the gRPC service. Defaults to 'service'. -- `address_list` (list[str]): A list of addresses. +- `method_name` (str): The name of the service method for which we want to get metadata. +- `model_id` (str): The model ID for which we want to get metadata. ###### returns: -- The response from the update model access request. (Any) +- Object with dateset requirements. (MethodMetadata) -_Note_: Returns an exception if an error occurs during the update model access request. +#### `update_model` -#### `get_all_models` +Creates a necessary request object `UpdateModelRequest` from stub and calls the +`get_model` grpc method using the `_call_method` method. + +###### args: + +- `model_id` (str): The model ID to update. +- `model_name` (str): New name of the model. (Optional) +- `description` (str): New description of the model. (Optional) +- `addresses_with_access` (list[str]): Updated list of addresses with access to the model. (Optional) + +###### returns: + +- Updated model. (Model) + +###### raises: + +- `NoSuchModelException`: If the model with the specified ID does not exist. +- `GRPCException`: If the gRPC call fails. + +#### `upload_and_validate` + +Checks dataset using the `_check_dataset` method. Sets price and model id into payment strategy. Creates a +generator that returns request objects `UploadInput` with file data one byte at a time. Calls the `upload_and_validate` +grpc method using the `_call_method` method. + +###### args: + +- `model_id` (str): The model ID to validate. +- `zip_path` (str | Path | PurePath): The path to the dataset. +- `price` (int): The price of the method call. + +###### returns: -Calls the `get_all_models` method in the gRPC training service stub to get all models. +- Status of the model. (ModelStatus) + +###### raises: + +- `NoSuchModelException`: If the model with the specified ID does not exist. +- `GRPCException`: If the gRPC call fails. + +#### `train_model` + +Sets price and model id into payment strategy. Creates a necessary request object `UpdateModelRequest` from stub +and calls the `train_model` grpc method using the `_call_method` method. + +###### args: + +- `model_id` (str): The model ID to train. +- `price` (int): The price of the method call. + +###### returns: + +- Status of the model. (ModelStatus) + +###### raises: + +- `NoSuchModelException`: If the model with the specified ID does not exist. +- `GRPCException`: If the gRPC call fails. + +#### `_call_method` + +Calls the specified gRPC training method and returns the response. + +###### args: + +- `method_name` (str): The name of the gRPC method to call. +- `request_data` (Any): The request data to pass to the gRPC method. +- `paid` (bool): Whether the method is paid or not. Defaults to _False_. + +###### returns: + +- Response from the gRPC method. (Any) + +###### raises: + +- `GRPCException`: If the gRPC call fails. + +#### `_get_training_stub` + +Creates a gRPC stub for the training with gRPC channel. + +###### args: + +- `paid` (bool): Whether the method is paid or not. Defaults to _False_. + +###### returns: + +- gRPC stub. (Any) + +#### `_get_auth_details` + +Creates a necessary request object `AuthorizationDetails` from stub with user data and signature. + +###### args: + +-`method_msg` (ModelMethodMessage): The message of the method. + +###### returns: + +- Authorization stub object. (Any) + +#### `_check_method_name` + +Checks if the specified method name is valid using `get_services_and_messages_info` method of the service client. + +###### args: + +- `method_name` (str): The name of the method. + +###### returns: + +- Service name and method name. (tuple[str, str]) + +###### raises: + +- `WrongMethodException`: If the method name is invalid. + +#### `_check_training` + +Checks if training is enabled in the service using `get_training_metadata` method. + +###### returns: + +- Whether training is enabled or not. (bool) + +#### `_check_dataset` + +Checks the dataset for compliance with the requirements (which is obtained via the `get_method_metadata` method). ###### args: -- `service_client` (ServiceClient): The client object for the service. -- `grpc_method_name` (str): The name of the gRPC method to be invoked. -- `grpc_service_name` (str): The name of the gRPC service. Defaults to 'service'. +- `model_id` (str): The model ID. +- `zip_path` (str | Path | PurePath): The path to the dataset. ###### returns: -- The response from the get all models request. (Any) +- _None_ -_Note_: Returns an exception if an error occurs during the get all models request. +###### raises: + +- `WrongDatasetException`: If the dataset is not valid. + +#### `_get_grpc_channel` + +Creates a gRPC channel for paid methods. +###### args: + +- `base_channel` (grpc.Channel): The base gRPC channel. + +###### returns: +- gRPC channel with interceptor for paid training methods. (grpc.Channel) diff --git a/docs/utils/call_utils.md b/docs/utils/call_utils.md new file mode 100644 index 0000000..f162f74 --- /dev/null +++ b/docs/utils/call_utils.md @@ -0,0 +1,24 @@ +## module: sdk.utils.call_utils + +[Link](https://github.com/singnet/snet-sdk-python/blob/master/snet/sdk/utils/call_utils.py) to GitHub + +Entities: +1. [_ClientCallDetails](#class-_clientcalldetails) +2. [create_intercept_call_func](#function-create_intercept_call_func) + +### Class `_ClientCallDetails` + +extends `grpc.ClientCallDetails`, `namedtuple` + +is extended by: - + +### Function `create_intercept_call_func` + +###### args: + +- `get_metadata_func` (callable): The function to get metadata for the call. +- `service_client` (ServiceClient): The service client to use for the call. + +###### returns: + +- The function to intercept the call. (callable) \ No newline at end of file diff --git a/docs/utils/utils.md b/docs/utils/utils.md index ee8ec12..e8d8676 100644 --- a/docs/utils/utils.md +++ b/docs/utils/utils.md @@ -10,10 +10,11 @@ Entities: 5. [is_valid_endpoint](#function-is_valid_endpoint) 6. [normalize_private_key](#function-normalize_private_key) 7. [get_address_from_private](#function-get_address_from_private) -8. [add_to_path](#class-add_to_path) -9. [find_file_by_keyword](#function-find_file_by_keyword) -10. [bytesuri_to_hash](#function-bytesuri_to_hash) -11. [safe_extract_proto](#function-safe_extract_proto) +8. [get_current_block_number](#function-get_current_block_number) +9. [add_to_path](#class-add_to_path) +10. [find_file_by_keyword](#function-find_file_by_keyword) +11. [bytesuri_to_hash](#function-bytesuri_to_hash) +12. [safe_extract_proto](#function-safe_extract_proto) ### Function `safe_address_converter` @@ -67,6 +68,7 @@ Generated files as well as .proto files are stored in the `~/.snet` directory. - `codegen_dir` (PurePath): The directory where the compiled code will be generated. - `proto_file` (str): The name of the .proto file to compile. Defaults to `None`. - `target_language` (str, optional): The target language for the compiled code. Defaults to "python". +- `add_training` (bool): Whether to include training.proto in the compilation. Defaults to False. ###### returns: @@ -116,6 +118,14 @@ Returns the wallet address from the private key. - The wallet address. (ChecksumAddress) +### Function `get_current_block_number` + +Returns the current block number in Ethereum. + +###### returns: + +- The current block number. (BlockNumber (int)) + ### Class `add_to_path` `add_to_path` class is a _**context manager**_ that temporarily adds a given path to the system's `sys.path` list. @@ -135,6 +145,7 @@ Finds a file by keyword in the current directory and subdirectories. - `directory` (AnyStr | PathLike[AnyStr]): The directory to search in. - `keyword` (AnyStr): The keyword to search for. +- `exclude` (List[AnyStr], optional): A list of strings to exclude from the search. Defaults to _None_. ###### returns: diff --git a/snet/sdk/__init__.py b/snet/sdk/__init__.py index 850e1bc..ee8b4e7 100644 --- a/snet/sdk/__init__.py +++ b/snet/sdk/__init__.py @@ -1,47 +1,42 @@ import importlib import os -from pathlib import Path import sys -from typing import Any, NewType import warnings import google.protobuf.internal.api_implementation -from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider +from google.protobuf import symbol_database as _symbol_database + +from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata with warnings.catch_warnings(): # Suppress the eth-typing package`s warnings related to some new networks - warnings.filterwarnings("ignore", "Network .* does not have a valid ChainId. eth-typing should be " - "updated with the latest networks.", UserWarning) - from snet.sdk.storage_provider.storage_provider import StorageProvider + warnings.filterwarnings( + "ignore", + "Network .* does not have a valid ChainId. eth-typing should be " + "updated with the latest networks.", + UserWarning + ) -from snet.sdk.payment_strategies.default_payment_strategy import DefaultPaymentStrategy -from snet.sdk.client_lib_generator import ClientLibGenerator + import web3 + +from snet.contracts import get_contract_object +from snet.sdk.account import Account from snet.sdk.config import Config -from snet.sdk.utils.utils import bytes32_to_str, type_converter +from snet.sdk.client_lib_generator import ClientLibGenerator +from snet.sdk.mpe.mpe_contract import MPEContract +from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider +from snet.sdk.payment_strategies.default_payment_strategy import DefaultPaymentStrategy as PaymentStrategy +from snet.sdk.service_client import ServiceClient +from snet.sdk.storage_provider.storage_provider import StorageProvider +from snet.sdk.custom_typing import ModuleName, ServiceStub +from snet.sdk.utils.utils import (bytes32_to_str, find_file_by_keyword, + type_converter) google.protobuf.internal.api_implementation.Type = lambda: 'python' - -from google.protobuf import symbol_database as _symbol_database - _sym_db = _symbol_database.Default() _sym_db.RegisterMessage = lambda x: None -import web3 - -from snet.sdk.service_client import ServiceClient -from snet.sdk.account import Account -from snet.sdk.mpe.mpe_contract import MPEContract - -from snet.contracts import get_contract_object - -from snet.sdk.storage_provider.service_metadata import mpe_service_metadata_from_json -from snet.sdk.utils.ipfs_utils import get_from_ipfs_and_checkhash -from snet.sdk.utils.utils import find_file_by_keyword - -ModuleName = NewType('ModuleName', str) -ServiceStub = NewType('ServiceStub', Any) - class SnetSDK: """Base Snet SDK""" @@ -54,116 +49,156 @@ def __init__(self, sdk_config: Config, metadata_provider=None): eth_rpc_endpoint = self._sdk_config["eth_rpc_endpoint"] eth_rpc_request_kwargs = self._sdk_config.get("eth_rpc_request_kwargs") - provider = web3.HTTPProvider(endpoint_uri=eth_rpc_endpoint, request_kwargs=eth_rpc_request_kwargs) + provider = web3.HTTPProvider(endpoint_uri=eth_rpc_endpoint, + request_kwargs=eth_rpc_request_kwargs) self.web3 = web3.Web3(provider) - # Get MPE contract address from config if specified; mostly for local testing - _mpe_contract_address = self._sdk_config.get("mpe_contract_address", None) + # Get MPE contract address from config if specified; + # mostly for local testing + _mpe_contract_address = self._sdk_config.get("mpe_contract_address", + None) if _mpe_contract_address is None: self.mpe_contract = MPEContract(self.web3) else: self.mpe_contract = MPEContract(self.web3, _mpe_contract_address) - # Get Registry contract address from config if specified; mostly for local testing - _registry_contract_address = self._sdk_config.get("registry_contract_address", None) + # Get Registry contract address from config if specified; + # mostly for local testing + _registry_contract_address = self._sdk_config.get( + "registry_contract_address", + None + ) if _registry_contract_address is None: self.registry_contract = get_contract_object(self.web3, "Registry") else: - self.registry_contract = get_contract_object(self.web3, "Registry", _registry_contract_address) + self.registry_contract = get_contract_object( + self.web3, + "Registry", + _registry_contract_address + ) if self._metadata_provider is None: - self._metadata_provider = StorageProvider(self._sdk_config, self.registry_contract) + self._metadata_provider = StorageProvider(self._sdk_config, + self.registry_contract) self.account = Account(self.web3, sdk_config, self.mpe_contract) - self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract) - - def create_service_client(self, org_id: str, service_id: str, group_name=None, - payment_channel_management_strategy=None, + self.payment_channel_provider = PaymentChannelProvider( + self.web3, + self.mpe_contract + ) + + def create_service_client(self, + org_id: str, + service_id: str, + group_name=None, + payment_strategy=None, free_call_auth_token_bin=None, free_call_token_expiry_block=None, email=None, options=None, - concurrent_calls=1): + concurrent_calls: int = 1): - # Create and instance of the Config object, so we can create an instance of ClientLibGenerator - lib_generator = ClientLibGenerator(self._metadata_provider, org_id, service_id) + # Create and instance of the Config object, + # so we can create an instance of ClientLibGenerator + self.lib_generator = ClientLibGenerator(self._metadata_provider, + org_id, service_id) # Download the proto file and generate stubs if needed force_update = self._sdk_config.get('force_update', False) if force_update: - lib_generator.generate_client_library() + self.lib_generator.generate_client_library() else: - path_to_pb_files = self.get_path_to_pb_files(org_id, service_id) - pb_2_file_name = find_file_by_keyword(path_to_pb_files, keyword="pb2.py") - pb_2_grpc_file_name = find_file_by_keyword(path_to_pb_files, keyword="pb2_grpc.py") + path_to_pb_files = self.lib_generator.protodir + pb_2_file_name = find_file_by_keyword(path_to_pb_files, + keyword="pb2.py", + exclude=["training"]) + pb_2_grpc_file_name = find_file_by_keyword(path_to_pb_files, + keyword="pb2_grpc.py", + exclude=["training"]) if not pb_2_file_name or not pb_2_grpc_file_name: - lib_generator.generate_client_library() - - if payment_channel_management_strategy is None: - payment_channel_management_strategy = DefaultPaymentStrategy(concurrent_calls) + print("Generating client library...") + self.lib_generator.generate_client_library() + + if payment_strategy is None: + payment_strategy = PaymentStrategy( + concurrent_calls=concurrent_calls + ) if options is None: options = dict() - options['free_call_auth_token-bin'] = bytes.fromhex(free_call_auth_token_bin) if\ + options['free_call_auth_token-bin'] = ( + bytes.fromhex(free_call_auth_token_bin) if free_call_token_expiry_block else "" - options['free-call-token-expiry-block'] = free_call_token_expiry_block if\ - free_call_token_expiry_block else 0 + ) + options['free-call-token-expiry-block'] = ( + free_call_token_expiry_block if free_call_token_expiry_block else 0 + ) options['email'] = email if email else "" options['concurrency'] = self._sdk_config.get("concurrency", True) - service_metadata = self._metadata_provider.enhance_service_metadata(org_id, service_id) + service_metadata = self._metadata_provider.enhance_service_metadata( + org_id, service_id + ) group = self._get_service_group_details(service_metadata, group_name) - strategy = payment_channel_management_strategy - - service_stub = self.get_service_stub(org_id, service_id) - - pb2_module = self.get_module_by_keyword(org_id, service_id, keyword="pb2.py") - - _service_client = ServiceClient(org_id, service_id, service_metadata, group, service_stub, strategy, - options, self.mpe_contract, self.account, self.web3, pb2_module, self.payment_channel_provider) + + service_stubs = self.get_service_stub() + + pb2_module = self.get_module_by_keyword(keyword="pb2.py") + _service_client = ServiceClient(org_id, service_id, service_metadata, + group, service_stubs, payment_strategy, + options, self.mpe_contract, + self.account, self.web3, pb2_module, + self.payment_channel_provider, + self.lib_generator.protodir, + self.lib_generator.training_added()) return _service_client - def get_service_stub(self, org_id: str, service_id: str) -> ServiceStub: - path_to_pb_files = self.get_path_to_pb_files(org_id, service_id) - module_name = self.get_module_by_keyword(org_id, service_id, keyword="pb2_grpc.py") - package_path = os.path.dirname(path_to_pb_files) - sys.path.append(package_path) + def get_service_stub(self) -> list[ServiceStub]: + path_to_pb_files = str(self.lib_generator.protodir) + module_name = self.get_module_by_keyword(keyword="pb2_grpc.py") + sys.path.append(path_to_pb_files) try: grpc_file = importlib.import_module(module_name) properties_and_methods_of_grpc_file = dir(grpc_file) - class_stub = [elem for elem in properties_and_methods_of_grpc_file if 'Stub' in elem][0] - service_stub = getattr(grpc_file, class_stub) - return ServiceStub(service_stub) + service_stubs = [] + for elem in properties_and_methods_of_grpc_file: + if 'Stub' in elem: + service_stubs.append(getattr(grpc_file, elem)) + return [ServiceStub(service_stub) for service_stub in service_stubs] + except Exception as e: raise Exception(f"Error importing module: {e}") - def get_path_to_pb_files(self, org_id: str, service_id: str) -> str: - client_libraries_base_dir_path = Path("~").expanduser().joinpath(".snet") - path_to_pb_files = f"{client_libraries_base_dir_path}/{org_id}/{service_id}/python/" - return path_to_pb_files - - def get_module_by_keyword(self, org_id: str, service_id: str, keyword: str) -> ModuleName: - path_to_pb_files = self.get_path_to_pb_files(org_id, service_id) - file_name = find_file_by_keyword(path_to_pb_files, keyword) + def get_module_by_keyword(self, keyword: str) -> ModuleName: + path_to_pb_files = self.lib_generator.protodir + file_name = find_file_by_keyword(path_to_pb_files, + keyword, + exclude=["training"]) module_name = os.path.splitext(file_name)[0] return ModuleName(module_name) def get_service_metadata(self, org_id, service_id): - return self._metadata_provider.fetch_service_metadata(org_id, service_id) + return self._metadata_provider.fetch_service_metadata(org_id, + service_id) - def _get_first_group(self, service_metadata): + def _get_first_group(self, service_metadata: MPEServiceMetadata) -> dict: return service_metadata['groups'][0] - def _get_group_by_group_name(self, service_metadata, group_name): + def _get_group_by_group_name(self, + service_metadata: MPEServiceMetadata, + group_name: str) -> dict: for group in service_metadata['groups']: if group['group_name'] == group_name: return group return {} - def _get_service_group_details(self, service_metadata, group_name): + def _get_service_group_details(self, + service_metadata: MPEServiceMetadata, + group_name: str) -> dict: if len(service_metadata['groups']) == 0: - raise Exception("No Groups found for given service, Please add group to the service") + raise Exception("No Groups found for given service, " + "Please add group to the service") if group_name is None: return self._get_first_group(service_metadata) @@ -178,11 +213,11 @@ def get_organization_list(self) -> list: return organization_list def get_services_list(self, org_id: str) -> list: - found, org_service_list = self.registry_contract.functions.listServicesForOrganization( - type_converter("bytes32")(org_id) - ).call() + found, org_service_list = ( + self.registry_contract.functions + .listServicesForOrganization(type_converter("bytes32")(org_id)) + .call()) if not found: raise Exception(f"Organization with id={org_id} doesn't exist!") org_service_list = list(map(bytes32_to_str, org_service_list)) return org_service_list - diff --git a/snet/sdk/account.py b/snet/sdk/account.py index df71493..745bf3b 100644 --- a/snet/sdk/account.py +++ b/snet/sdk/account.py @@ -1,14 +1,22 @@ import json -from snet.sdk.utils.utils import get_address_from_private, normalize_private_key +import web3 + from snet.contracts import get_contract_object +from snet.sdk.config import Config +from snet.sdk.mpe.mpe_contract import MPEContract +from snet.sdk.utils.utils import (get_address_from_private, + normalize_private_key) DEFAULT_GAS = 300000 TRANSACTION_TIMEOUT = 500 class TransactionError(Exception): - """Raised when an Ethereum transaction receipt has a status of 0. Can provide a custom message. Optionally includes receipt""" + """ + Raised when an Ethereum transaction receipt has a status of 0. + Can provide a custom message. Optionally includes receipt + """ def __init__(self, message, receipt=None): super().__init__(message) @@ -20,11 +28,14 @@ def __str__(self): class Account: - def __init__(self, w3, config, mpe_contract): - self.config = config - self.web3 = w3 - self.mpe_contract = mpe_contract - _token_contract_address = self.config.get("token_contract_address", None) + def __init__(self, w3: web3.Web3, config: Config, + mpe_contract: MPEContract): + self.config: Config = config + self.web3: web3.Web3 = w3 + self.mpe_contract: MPEContract = mpe_contract + _token_contract_address: str | None = self.config.get( + "token_contract_address", None + ) if _token_contract_address is None: self.token_contract = get_contract_object( self.web3, "SingularityNetToken") @@ -32,13 +43,12 @@ def __init__(self, w3, config, mpe_contract): self.token_contract = get_contract_object( self.web3, "SingularityNetToken", _token_contract_address) - private_key = config.get("private_key", None) - signer_private_key = config.get("signer_private_key", None) - if private_key is not None: - self.private_key = normalize_private_key(config["private_key"]) - if signer_private_key is not None: + if config.get("private_key") is not None: + self.private_key = normalize_private_key(config.get("private_key")) + if config.get("signer_private_key") is not None: self.signer_private_key = normalize_private_key( - config["signer_private_key"]) + config.get("signer_private_key") + ) else: self.signer_private_key = self.private_key self.address = get_address_from_private(self.private_key) @@ -73,17 +83,21 @@ def _send_signed_transaction(self, contract_fn, *args): }) signed_txn = self.web3.eth.account.sign_transaction( transaction, private_key=self.private_key) - return self.web3.to_hex(self.web3.eth.send_raw_transaction(signed_txn.rawTransaction)) + return self.web3.to_hex( + self.web3.eth.send_raw_transaction(signed_txn.rawTransaction) + ) def send_transaction(self, contract_fn, *args): txn_hash = self._send_signed_transaction(contract_fn, *args) - return self.web3.eth.wait_for_transaction_receipt(txn_hash, TRANSACTION_TIMEOUT) + return self.web3.eth.wait_for_transaction_receipt(txn_hash, + TRANSACTION_TIMEOUT) def _parse_receipt(self, receipt, event, encoder=json.JSONEncoder): if receipt.status == 0: raise TransactionError("Transaction failed", receipt) else: - return json.dumps(dict(event().processReceipt(receipt)[0]["args"]), cls=encoder) + return json.dumps(dict(event().processReceipt(receipt)[0]["args"]), + cls=encoder) def escrow_balance(self): return self.mpe_contract.balance(self.address) @@ -95,8 +109,12 @@ def deposit_to_escrow_account(self, amount_in_cogs): return self.mpe_contract.deposit(self, amount_in_cogs) def approve_transfer(self, amount_in_cogs): - return self.send_transaction(self.token_contract.functions.approve, self.mpe_contract.contract.address, + return self.send_transaction(self.token_contract.functions.approve, + self.mpe_contract.contract.address, amount_in_cogs) def allowance(self): - return self.token_contract.functions.allowance(self.address, self.mpe_contract.contract.address).call() + return self.token_contract.functions.allowance( + self.address, + self.mpe_contract.contract.address + ).call() diff --git a/snet/sdk/client_lib_generator.py b/snet/sdk/client_lib_generator.py index 3933ba1..934f09f 100644 --- a/snet/sdk/client_lib_generator.py +++ b/snet/sdk/client_lib_generator.py @@ -1,50 +1,70 @@ import os -from pathlib import Path, PurePath +from pathlib import Path -from snet.sdk import StorageProvider -from snet.sdk.utils import ipfs_utils -from snet.sdk.utils.utils import compile_proto, type_converter +from snet.sdk.storage_provider.storage_provider import StorageProvider +from snet.sdk.utils.utils import compile_proto class ClientLibGenerator: + def __init__(self, metadata_provider: StorageProvider, org_id: str, + service_id: str, protodir: Path | None = None): + self._metadata_provider: StorageProvider = metadata_provider + self.org_id: str = org_id + self.service_id: str = service_id + self.language: str = "python" + self.protodir: Path = (protodir if protodir else + Path.home().joinpath(".snet")) + self.generate_directories_by_params() - def __init__(self, metadata_provider, org_id, service_id): - self._metadata_provider = metadata_provider - self.org_id = org_id - self.service_id = service_id - self.language = "python" - self.protodir = Path("~").expanduser().joinpath(".snet") - - def generate_client_library(self): - - if os.path.isabs(self.protodir): - client_libraries_base_dir_path = PurePath(self.protodir) - else: - cur_dir_path = PurePath(os.getcwd()) - client_libraries_base_dir_path = cur_dir_path.joinpath(self.protodir) - - os.makedirs(client_libraries_base_dir_path, exist_ok=True) - - # Create service client libraries path - library_language = self.language - library_org_id = self.org_id - library_service_id = self.service_id + def generate_client_library(self) -> None: + try: + self.receive_proto_files() + compilation_result = compile_proto(entry_path=self.protodir, + codegen_dir=self.protodir, + target_language=self.language, + add_training=self.training_added()) + if compilation_result: + print(f'client libraries for service with id "{self.service_id}" ' + f'in org with id "{self.org_id}" ' + f'generated at {self.protodir}') + except Exception as e: + print(str(e)) - library_dir_path = client_libraries_base_dir_path.joinpath(library_org_id, - library_service_id, - library_language) + def generate_directories_by_params(self) -> None: + if not self.protodir.is_absolute(): + self.protodir = Path.cwd().joinpath(self.protodir) + self.create_service_client_libraries_path() - try: - metadata = self._metadata_provider.fetch_service_metadata(self.org_id, self.service_id) - service_api_source = metadata.get("service_api_source") or metadata.get("model_ipfs_hash") + def create_service_client_libraries_path(self) -> None: + self.protodir = self.protodir.joinpath(self.org_id, + self.service_id, + self.language) + self.protodir.mkdir(parents=True, exist_ok=True) - # Receive proto files - self._metadata_provider.fetch_and_extract_proto(service_api_source, library_dir_path) + def receive_proto_files(self) -> None: + metadata = self._metadata_provider.fetch_service_metadata( + org_id=self.org_id, + service_id=self.service_id + ) + service_api_source = (metadata.get("service_api_source") or + metadata.get("model_ipfs_hash")) - # Compile proto files - compile_proto(Path(library_dir_path), library_dir_path, target_language=self.language) + # Receive proto files + if self.protodir.exists(): + self._metadata_provider.fetch_and_extract_proto( + service_api_source, + self.protodir + ) + else: + raise Exception("Directory for storing proto files is not found") - print(f'client libraries for service with id "{library_service_id}" in org with id "{library_org_id}" ' - f'generated at {library_dir_path}') - except Exception as e: - print(e) + def training_added(self) -> bool: + files = os.listdir(self.protodir) + for file in files: + if ".proto" not in file: + continue + with open(self.protodir.joinpath(file), "r") as f: + proto_text = f.read() + if 'import "training.proto";' in proto_text: + return True + return False diff --git a/snet/sdk/concurrency_manager.py b/snet/sdk/concurrency_manager.py index 6c2a29a..5d86495 100644 --- a/snet/sdk/concurrency_manager.py +++ b/snet/sdk/concurrency_manager.py @@ -3,18 +3,19 @@ import grpc import web3 +from snet.sdk.service_client import ServiceClient from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path class ConcurrencyManager: - def __init__(self, concurrent_calls): - self.__concurrent_calls = concurrent_calls - self.__token = '' - self.__planned_amount = 0 - self.__used_amount = 0 + def __init__(self, concurrent_calls: int): + self.__concurrent_calls: int = concurrent_calls + self.__token: str = '' + self.__planned_amount: int = 0 + self.__used_amount: int = 0 @property - def concurrent_calls(self): + def concurrent_calls(self) -> int: return self.__concurrent_calls def get_token(self, service_client, channel, service_call_price): @@ -24,7 +25,7 @@ def get_token(self, service_client, channel, service_call_price): self.__token = self.__get_token(service_client, channel, service_call_price, new_token=True) return self.__token - def __get_token(self, service_client, channel, service_call_price, new_token=False): + def __get_token(self, service_client: ServiceClient, channel, service_call_price, new_token=False): if not new_token: amount = channel.state["last_signed_amount"] if amount != 0: @@ -46,13 +47,13 @@ def __get_token(self, service_client, channel, service_call_price, new_token=Fal self.__planned_amount = token_reply.planned_amount return token_reply.token - def __get_stub_for_get_token(self, service_client): + def __get_stub_for_get_token(self, service_client: ServiceClient): grpc_channel = service_client.get_grpc_base_channel() with add_to_path(str(RESOURCES_PATH.joinpath("proto"))): token_service_pb2_grpc = importlib.import_module("token_service_pb2_grpc") return token_service_pb2_grpc.TokenServiceStub(grpc_channel) - def __get_token_for_amount(self, service_client, channel, amount): + def __get_token_for_amount(self, service_client: ServiceClient, channel, amount): nonce = channel.state["nonce"] stub = self.__get_stub_for_get_token(service_client) with add_to_path(str(RESOURCES_PATH.joinpath("proto"))): diff --git a/snet/sdk/config.py b/snet/sdk/config.py index e8f8b04..938b01d 100644 --- a/snet/sdk/config.py +++ b/snet/sdk/config.py @@ -16,7 +16,8 @@ def __init__(self, "private_key": private_key, "eth_rpc_endpoint": eth_rpc_endpoint, "wallet_index": wallet_index, - "ipfs_endpoint": ipfs_endpoint if ipfs_endpoint else "/dns/ipfs.singularitynet.io/tcp/80/", + "ipfs_endpoint": (ipfs_endpoint if ipfs_endpoint + else "/dns/ipfs.singularitynet.io/tcp/80/"), "concurrency": concurrency, "force_update": force_update, "mpe_contract_address": mpe_contract_address, @@ -34,4 +35,3 @@ def get(self, key, default=None): def get_ipfs_endpoint(self): return self["ipfs_endpoint"] - diff --git a/snet/sdk/custom_typing.py b/snet/sdk/custom_typing.py new file mode 100644 index 0000000..e6ee8c2 --- /dev/null +++ b/snet/sdk/custom_typing.py @@ -0,0 +1,5 @@ +from typing import Any, NewType + + +ModuleName = NewType('ModuleName', str) +ServiceStub = NewType('ServiceStub', Any) diff --git a/snet/sdk/payment_strategies/default_payment_strategy.py b/snet/sdk/payment_strategies/default_payment_strategy.py index d8582d7..c66f711 100644 --- a/snet/sdk/payment_strategies/default_payment_strategy.py +++ b/snet/sdk/payment_strategies/default_payment_strategy.py @@ -7,7 +7,7 @@ class DefaultPaymentStrategy(PaymentStrategy): - def __init__(self, concurrent_calls=1): + def __init__(self, concurrent_calls: int = 1): self.concurrent_calls = concurrent_calls self.concurrencyManager = ConcurrencyManager(concurrent_calls) self.channel = None diff --git a/snet/sdk/payment_strategies/payment_strategy.py b/snet/sdk/payment_strategies/payment_strategy.py index d840eb9..2646ba1 100644 --- a/snet/sdk/payment_strategies/payment_strategy.py +++ b/snet/sdk/payment_strategies/payment_strategy.py @@ -1,7 +1,7 @@ class PaymentStrategy(object): - def get_payment_metadata(self,service_client): + def get_payment_metadata(self, service_client): pass - def get_price(self,service_client): + def get_price(self, service_client): pass diff --git a/snet/sdk/payment_strategies/prepaid_payment_strategy.py b/snet/sdk/payment_strategies/prepaid_payment_strategy.py index aaaf3b3..3c6f016 100644 --- a/snet/sdk/payment_strategies/prepaid_payment_strategy.py +++ b/snet/sdk/payment_strategies/prepaid_payment_strategy.py @@ -1,9 +1,11 @@ +from snet.sdk.concurrency_manager import ConcurrencyManager from snet.sdk.payment_strategies.payment_strategy import PaymentStrategy class PrePaidPaymentStrategy(PaymentStrategy): - def __init__(self, concurrency_manager, block_offset=240, call_allowance=1): + def __init__(self, concurrency_manager: ConcurrencyManager, + block_offset: int = 240, call_allowance: int = 1): self.concurrency_manager = concurrency_manager self.block_offset = block_offset self.call_allowance = call_allowance diff --git a/snet/sdk/payment_strategies/training_payment_strategy.py b/snet/sdk/payment_strategies/training_payment_strategy.py new file mode 100644 index 0000000..331802d --- /dev/null +++ b/snet/sdk/payment_strategies/training_payment_strategy.py @@ -0,0 +1,46 @@ +import web3 + +from snet.sdk.payment_strategies.paidcall_payment_strategy import PaidCallPaymentStrategy + + +class TrainingPaymentStrategy(PaidCallPaymentStrategy): + def __init__(self): + super().__init__() + self._call_price = -1 + self._train_model_id = "" + + def get_price(self, service_client=None) -> int: + if self._call_price == -1: + raise Exception("Training call price not set") + return self._call_price + + def set_price(self, call_price: int) -> None: + self._call_price = call_price + + def get_model_id(self): + return self._train_model_id + + def set_model_id(self, model_id: str): + self._train_model_id = model_id + + def get_payment_metadata(self, service_client) -> list[tuple[str, str]]: + channel = self.select_channel(service_client) + amount = channel.state["last_signed_amount"] + int(self.get_price(service_client)) + message = web3.Web3.solidity_keccak( + ["string", "address", "uint256", "uint256", "uint256"], + ["__MPE_claim_message", service_client.mpe_address, channel.channel_id, + channel.state["nonce"], + amount] + ) + signature = service_client.generate_signature(message) + + metadata = [ + ("snet-payment-type", "train-call"), + ("snet-payment-channel-id", str(channel.channel_id)), + ("snet-payment-channel-nonce", str(channel.state["nonce"])), + ("snet-payment-channel-amount", str(amount)), + ("snet-train-model-id", self.get_model_id()), + ("snet-payment-channel-signature-bin", signature) + ] + + return metadata diff --git a/snet/sdk/resources/proto/merckledag.proto b/snet/sdk/resources/proto/merckledag.proto deleted file mode 100644 index 5af078a..0000000 --- a/snet/sdk/resources/proto/merckledag.proto +++ /dev/null @@ -1,17 +0,0 @@ -syntax = "proto2"; -// An IPFS MerkleDAG Link -message MerkleLink { - required bytes Hash = 1; // multihash of the target object - required string Name = 2; // utf string name - required uint64 Tsize = 3; // cumulative size of target object - - // user extensions start at 50 -} - -// An IPFS MerkleDAG Node -message MerkleNode { - repeated MerkleLink Links = 2; // refs to other objects - required bytes Data = 1; // opaque user data - - // user extensions start at 50 -} \ No newline at end of file diff --git a/snet/sdk/resources/proto/merckledag_pb2.py b/snet/sdk/resources/proto/merckledag_pb2.py deleted file mode 100644 index c47ebb5..0000000 --- a/snet/sdk/resources/proto/merckledag_pb2.py +++ /dev/null @@ -1,27 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: merckledag.proto -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10merckledag.proto\"7\n\nMerkleLink\x12\x0c\n\x04Hash\x18\x01 \x02(\x0c\x12\x0c\n\x04Name\x18\x02 \x02(\t\x12\r\n\x05Tsize\x18\x03 \x02(\x04\"6\n\nMerkleNode\x12\x1a\n\x05Links\x18\x02 \x03(\x0b\x32\x0b.MerkleLink\x12\x0c\n\x04\x44\x61ta\x18\x01 \x02(\x0c') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'merckledag_pb2', _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals['_MERKLELINK']._serialized_start=20 - _globals['_MERKLELINK']._serialized_end=75 - _globals['_MERKLENODE']._serialized_start=77 - _globals['_MERKLENODE']._serialized_end=131 -# @@protoc_insertion_point(module_scope) diff --git a/snet/sdk/resources/proto/merckledag_pb2_grpc.py b/snet/sdk/resources/proto/merckledag_pb2_grpc.py deleted file mode 100644 index 2daafff..0000000 --- a/snet/sdk/resources/proto/merckledag_pb2_grpc.py +++ /dev/null @@ -1,4 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - diff --git a/snet/sdk/resources/proto/training.proto b/snet/sdk/resources/proto/training.proto index 86b3074..8582c66 100644 --- a/snet/sdk/resources/proto/training.proto +++ b/snet/sdk/resources/proto/training.proto @@ -1,112 +1,128 @@ syntax = "proto3"; -import "google/protobuf/descriptor.proto"; +import "google/protobuf/descriptor.proto"; // Required for indicators to work package training; -option go_package = "../training"; -//Please note that the AI developers need to provide a server implementation of the gprc server of this proto. -message ModelDetails { - //This Id will be generated when you invoke the create_model method and hence doesnt need to be filled when you - //invoke the create model - string model_id = 1; - //define the training method name - string grpc_method_name = 2; - //define the grpc service name , under which the method is defined - string grpc_service_name = 3; - string description = 4; +option go_package = "github.com/singnet/snet-daemon/v5/training"; - string status = 6; - string updated_date = 7; - //List of all the addresses that will have access to this model - repeated string address_list = 8; - // this is optional - string training_data_link = 9; - string model_name = 10; +// Methods that the service provider must implement +service Model { + // Free + // Can pass the address of the model creator + rpc create_model(NewModel) returns (ModelID) {} - string organization_id = 11; - string service_id = 12 ; - string group_id = 13; + // Free + rpc validate_model_price(ValidateRequest) returns (PriceInBaseUnit) {} - //set this to true if you want your model to be used by other AI consumers - bool is_publicly_accessible = 14; + // Paid + rpc upload_and_validate(stream UploadInput) returns (StatusResponse) {} -} + // Paid + rpc validate_model(ValidateRequest) returns (StatusResponse) {} -message AuthorizationDetails { - uint64 current_block = 1; - //Signer can fill in any message here - string message = 2; - //signature of the following message: - //("user specified message", user_address, current_block_number) - bytes signature = 3; - string signer_address = 4; + // Free, one signature for both train_model_price & train_model methods + rpc train_model_price(ModelID) returns (PriceInBaseUnit) {} -} + // Paid + rpc train_model(ModelID) returns (StatusResponse) {} -enum Status { - CREATED = 0; - IN_PROGRESS = 1; - ERRORED = 2; - COMPLETED = 3; - DELETED = 4; -} + // Free + rpc delete_model(ModelID) returns (StatusResponse) { + // After model deletion, the status becomes DELETED in etcd + } -message CreateModelRequest { - AuthorizationDetails authorization = 1; - ModelDetails model_details = 2; + // Free + rpc get_model_status(ModelID) returns (StatusResponse) {} } -//the signer address will get to know all the models associated with this address. -message AccessibleModelsRequest { - string grpc_method_name = 1; - string grpc_service_name = 2; - AuthorizationDetails authorization = 3; -} +message ModelResponse { + string model_id = 1; + Status status = 2; + string created_date = 3; + string updated_date = 4; + string name = 5; + string description = 6; + string grpc_method_name = 7; + string grpc_service_name = 8; -message AccessibleModelsResponse { - repeated ModelDetails list_of_models = 1; -} + // List of all addresses that will have access to this model + repeated string address_list = 9; + + // Access to the model is granted only for use and viewing + bool is_public = 10; + + string training_data_link = 11; -message ModelDetailsRequest { - ModelDetails model_details = 1 ; - AuthorizationDetails authorization = 2; + string created_by_address = 12; + string updated_by_address = 13; } -//helps determine which service end point to call for model training -//format is of type "packageName/serviceName/MethodName", Example :"/example_service.Calculator/estimate_add" -//Daemon will invoke the model training end point , when the below method option is specified -message TrainingMethodOption { - string trainingMethodIndicator = 1; +// Used as input for new_model requests +// The service provider decides whether to use these fields; returning model_id is mandatory +message NewModel { + string name = 1; + string description = 2; + string grpc_method_name = 3; + string grpc_service_name = 4; + + // List of all addresses that will have access to this model + repeated string address_list = 5; + + // Set this to true if you want your model to be accessible by other AI consumers + bool is_public = 6; + + // These parameters will be passed by the daemon + string organization_id = 7; + string service_id = 8; + string group_id = 9; } -extend google.protobuf.MethodOptions { - TrainingMethodOption my_method_option = 9999197; +// This structure must be used by the service provider +message ModelID { + string model_id = 1; } -message UpdateModelRequest { - ModelDetails update_model_details = 1 ; - AuthorizationDetails authorization = 2; +// This structure must be used by the service provider +// Used in the train_model_price method to get the training/validation price +message PriceInBaseUnit { + uint64 price = 1; // cogs, weis, afet, aasi, etc. } +enum Status { + CREATED = 0; + VALIDATING = 1; + VALIDATED = 2; + TRAINING = 3; + READY_TO_USE = 4; // After training is completed + ERRORED = 5; + DELETED = 6; +} -message ModelDetailsResponse { +message StatusResponse { Status status = 1; - ModelDetails model_details = 2; - } -service Model { - - // The AI developer needs to Implement this service and Daemon will call these - // There will be no cost borne by the consumer in calling these methods, - // Pricing will apply when you actually call the training methods defined. - // AI consumer will call all these methods - rpc create_model(CreateModelRequest) returns (ModelDetailsResponse) {} - rpc delete_model(UpdateModelRequest) returns (ModelDetailsResponse) {} - rpc get_model_status(ModelDetailsRequest) returns (ModelDetailsResponse) {} - - // Daemon will implement , however the AI developer should skip implementing these and just provide dummy code. - rpc update_model_access(UpdateModelRequest) returns (ModelDetailsResponse) {} - rpc get_all_models(AccessibleModelsRequest) returns (AccessibleModelsResponse) {} +message UploadInput { + string model_id = 1; + bytes data = 2; + string file_name = 3; + uint64 file_size = 4; // in bytes + uint64 batch_size = 5; + uint64 batch_number = 6; + uint64 batch_count = 7; +} +message ValidateRequest { + string model_id = 2; + string training_data_link = 3; +} -} \ No newline at end of file +extend google.protobuf.MethodOptions { + string default_model_id = 50001; + uint64 max_models_per_user = 50002; // max models per method & user + uint64 dataset_max_size_mb = 50003; // max size of dataset + uint64 dataset_max_count_files = 50004; // maximum number of files in the dataset + uint64 dataset_max_size_single_file_mb = 50005; // maximum size of a single file in the dataset + string dataset_files_type = 50006; // allowed files types in dataset. string with array or single value, example: jpg, png, mp3 + string dataset_type = 50007; // string with array or single value, example: zip, tar.gz, tar + string dataset_description = 50008; // additional free-form requirements +} diff --git a/snet/sdk/resources/proto/training/training.proto b/snet/sdk/resources/proto/training/training.proto new file mode 100644 index 0000000..8582c66 --- /dev/null +++ b/snet/sdk/resources/proto/training/training.proto @@ -0,0 +1,128 @@ +syntax = "proto3"; +import "google/protobuf/descriptor.proto"; // Required for indicators to work +package training; +option go_package = "github.com/singnet/snet-daemon/v5/training"; + +// Methods that the service provider must implement +service Model { + + // Free + // Can pass the address of the model creator + rpc create_model(NewModel) returns (ModelID) {} + + // Free + rpc validate_model_price(ValidateRequest) returns (PriceInBaseUnit) {} + + // Paid + rpc upload_and_validate(stream UploadInput) returns (StatusResponse) {} + + // Paid + rpc validate_model(ValidateRequest) returns (StatusResponse) {} + + // Free, one signature for both train_model_price & train_model methods + rpc train_model_price(ModelID) returns (PriceInBaseUnit) {} + + // Paid + rpc train_model(ModelID) returns (StatusResponse) {} + + // Free + rpc delete_model(ModelID) returns (StatusResponse) { + // After model deletion, the status becomes DELETED in etcd + } + + // Free + rpc get_model_status(ModelID) returns (StatusResponse) {} +} + +message ModelResponse { + string model_id = 1; + Status status = 2; + string created_date = 3; + string updated_date = 4; + string name = 5; + string description = 6; + string grpc_method_name = 7; + string grpc_service_name = 8; + + // List of all addresses that will have access to this model + repeated string address_list = 9; + + // Access to the model is granted only for use and viewing + bool is_public = 10; + + string training_data_link = 11; + + string created_by_address = 12; + string updated_by_address = 13; +} + +// Used as input for new_model requests +// The service provider decides whether to use these fields; returning model_id is mandatory +message NewModel { + string name = 1; + string description = 2; + string grpc_method_name = 3; + string grpc_service_name = 4; + + // List of all addresses that will have access to this model + repeated string address_list = 5; + + // Set this to true if you want your model to be accessible by other AI consumers + bool is_public = 6; + + // These parameters will be passed by the daemon + string organization_id = 7; + string service_id = 8; + string group_id = 9; +} + +// This structure must be used by the service provider +message ModelID { + string model_id = 1; +} + +// This structure must be used by the service provider +// Used in the train_model_price method to get the training/validation price +message PriceInBaseUnit { + uint64 price = 1; // cogs, weis, afet, aasi, etc. +} + +enum Status { + CREATED = 0; + VALIDATING = 1; + VALIDATED = 2; + TRAINING = 3; + READY_TO_USE = 4; // After training is completed + ERRORED = 5; + DELETED = 6; +} + +message StatusResponse { + Status status = 1; +} + +message UploadInput { + string model_id = 1; + bytes data = 2; + string file_name = 3; + uint64 file_size = 4; // in bytes + uint64 batch_size = 5; + uint64 batch_number = 6; + uint64 batch_count = 7; +} + +message ValidateRequest { + string model_id = 2; + string training_data_link = 3; +} + +extend google.protobuf.MethodOptions { + string default_model_id = 50001; + uint64 max_models_per_user = 50002; // max models per method & user + uint64 dataset_max_size_mb = 50003; // max size of dataset + uint64 dataset_max_count_files = 50004; // maximum number of files in the dataset + uint64 dataset_max_size_single_file_mb = 50005; // maximum size of a single file in the dataset + string dataset_files_type = 50006; // allowed files types in dataset. string with array or single value, example: jpg, png, mp3 + string dataset_type = 50007; // string with array or single value, example: zip, tar.gz, tar + string dataset_description = 50008; // additional free-form requirements +} diff --git a/snet/sdk/resources/proto/training_daemon.proto b/snet/sdk/resources/proto/training_daemon.proto new file mode 100644 index 0000000..74c33d2 --- /dev/null +++ b/snet/sdk/resources/proto/training_daemon.proto @@ -0,0 +1,127 @@ +syntax = "proto3"; +package training; + +import "google/protobuf/descriptor.proto"; // Required for indicators to work +import "google/protobuf/struct.proto"; // Required for google.protobuf.ListValue +import "training.proto"; +import "google/protobuf/empty.proto"; +option go_package = "github.com/singnet/snet-daemon/v5/training"; + +message AuthorizationDetails { + uint64 current_block = 1; // Check for relevance within a range of +/- N blocks + // Signer can specify any message here + string message = 2; + // Signature of the following message: + // ("user specified message", user_address, current_block_number) + bytes signature = 3; + string signer_address = 4; +} + +message NewModelRequest { + AuthorizationDetails authorization = 1; + training.NewModel model = 2; +} + +message AuthValidateRequest { + AuthorizationDetails authorization = 1; + string model_id = 2; + string training_data_link = 3; +} + +message UploadAndValidateRequest { + AuthorizationDetails authorization = 1; + training.UploadInput upload_input = 2; +} + +message CommonRequest { + AuthorizationDetails authorization = 1; + string model_id = 2; +} + +message UpdateModelRequest { + AuthorizationDetails authorization = 1; + string model_id = 2; + optional string model_name = 3; + optional string description = 4; + repeated string address_list = 5; +} + +message ModelsResponse { + repeated training.ModelResponse list_of_models = 1; +} + +// These methods are implemented only by the daemon; the service provider should ignore them +service Daemon { + // Free + rpc create_model(NewModelRequest) returns (training.ModelResponse) {} + + // Free + rpc validate_model_price(AuthValidateRequest) returns (training.PriceInBaseUnit) {} + + // Paid + rpc upload_and_validate(stream UploadAndValidateRequest) returns (training.StatusResponse) {} + + // Paid + rpc validate_model(AuthValidateRequest) returns (training.StatusResponse) {} + + // Free, one signature for both train_model_price & train_model methods + rpc train_model_price(CommonRequest) returns (training.PriceInBaseUnit) {} + + // Paid + rpc train_model(CommonRequest) returns (training.StatusResponse) {} + + // Free + // After deleting the model, the status becomes DELETED in etcd + rpc delete_model(CommonRequest) returns (training.StatusResponse) {} + + rpc get_all_models(AllModelsRequest) returns (ModelsResponse) {} + + rpc get_model(CommonRequest) returns (training.ModelResponse) {} + + rpc update_model(UpdateModelRequest) returns (training.ModelResponse) {} + + // Unique methods by daemon + // One signature for all getters + rpc get_training_metadata(google.protobuf.Empty) returns (TrainingMetadata) {} + + // Free & without authorization + rpc get_method_metadata(MethodMetadataRequest) returns (MethodMetadata) {} +} + +message MethodMetadataRequest { + string model_id = 1; + // Model ID or gRPC method name + string grpc_method_name = 2; + string grpc_service_name = 3; +} + +message AllModelsRequest { + AuthorizationDetails authorization = 1; + // filters: + repeated training.Status statuses = 3; + optional bool is_public = 4; // null - all, false - only private, true - only public models + string grpc_method_name = 5; + string grpc_service_name = 6; + string name = 7; + string created_by_address = 8; + uint64 page_size = 9; + uint64 page = 10; +} + +message TrainingMetadata { + bool trainingEnabled = 1; + bool trainingInProto = 2; + // Key: grpc_service_name, Value: array of grpc_method_name + map trainingMethods = 3; +} + +message MethodMetadata { + string default_model_id = 50001; + uint64 max_models_per_user = 50002; // max models per method & user + uint64 dataset_max_size_mb = 50003; // max size of dataset + uint64 dataset_max_count_files = 50004; // maximum number of files in the dataset + uint64 dataset_max_size_single_file_mb = 50005; // maximum size of a single file in the dataset + string dataset_files_type = 50006; // allowed files types in dataset. string with array or single value, example: jpg, png, mp3 + string dataset_type = 50007; // string with array or single value, example: zip, tar.gz, tar + string dataset_description = 50008; // additional free-form requirements +} diff --git a/snet/sdk/resources/proto/training_daemon_pb2.py b/snet/sdk/resources/proto/training_daemon_pb2.py new file mode 100644 index 0000000..6923330 --- /dev/null +++ b/snet/sdk/resources/proto/training_daemon_pb2.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: training_daemon.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2 +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 +import training_pb2 as training__pb2 +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15training_daemon.proto\x12\x08training\x1a google/protobuf/descriptor.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x0etraining.proto\x1a\x1bgoogle/protobuf/empty.proto\"i\n\x14\x41uthorizationDetails\x12\x15\n\rcurrent_block\x18\x01 \x01(\x04\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x11\n\tsignature\x18\x03 \x01(\x0c\x12\x16\n\x0esigner_address\x18\x04 \x01(\t\"k\n\x0fNewModelRequest\x12\x35\n\rauthorization\x18\x01 \x01(\x0b\x32\x1e.training.AuthorizationDetails\x12!\n\x05model\x18\x02 \x01(\x0b\x32\x12.training.NewModel\"z\n\x13\x41uthValidateRequest\x12\x35\n\rauthorization\x18\x01 \x01(\x0b\x32\x1e.training.AuthorizationDetails\x12\x10\n\x08model_id\x18\x02 \x01(\t\x12\x1a\n\x12training_data_link\x18\x03 \x01(\t\"~\n\x18UploadAndValidateRequest\x12\x35\n\rauthorization\x18\x01 \x01(\x0b\x32\x1e.training.AuthorizationDetails\x12+\n\x0cupload_input\x18\x02 \x01(\x0b\x32\x15.training.UploadInput\"X\n\rCommonRequest\x12\x35\n\rauthorization\x18\x01 \x01(\x0b\x32\x1e.training.AuthorizationDetails\x12\x10\n\x08model_id\x18\x02 \x01(\t\"\xc5\x01\n\x12UpdateModelRequest\x12\x35\n\rauthorization\x18\x01 \x01(\x0b\x32\x1e.training.AuthorizationDetails\x12\x10\n\x08model_id\x18\x02 \x01(\t\x12\x17\n\nmodel_name\x18\x03 \x01(\tH\x00\x88\x01\x01\x12\x18\n\x0b\x64\x65scription\x18\x04 \x01(\tH\x01\x88\x01\x01\x12\x14\n\x0c\x61\x64\x64ress_list\x18\x05 \x03(\tB\r\n\x0b_model_nameB\x0e\n\x0c_description\"A\n\x0eModelsResponse\x12/\n\x0elist_of_models\x18\x01 \x03(\x0b\x32\x17.training.ModelResponse\"^\n\x15MethodMetadataRequest\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x18\n\x10grpc_method_name\x18\x02 \x01(\t\x12\x19\n\x11grpc_service_name\x18\x03 \x01(\t\"\x93\x02\n\x10\x41llModelsRequest\x12\x35\n\rauthorization\x18\x01 \x01(\x0b\x32\x1e.training.AuthorizationDetails\x12\"\n\x08statuses\x18\x03 \x03(\x0e\x32\x10.training.Status\x12\x16\n\tis_public\x18\x04 \x01(\x08H\x00\x88\x01\x01\x12\x18\n\x10grpc_method_name\x18\x05 \x01(\t\x12\x19\n\x11grpc_service_name\x18\x06 \x01(\t\x12\x0c\n\x04name\x18\x07 \x01(\t\x12\x1a\n\x12\x63reated_by_address\x18\x08 \x01(\t\x12\x11\n\tpage_size\x18\t \x01(\x04\x12\x0c\n\x04page\x18\n \x01(\x04\x42\x0c\n\n_is_public\"\xe2\x01\n\x10TrainingMetadata\x12\x17\n\x0ftrainingEnabled\x18\x01 \x01(\x08\x12\x17\n\x0ftrainingInProto\x18\x02 \x01(\x08\x12H\n\x0ftrainingMethods\x18\x03 \x03(\x0b\x32/.training.TrainingMetadata.TrainingMethodsEntry\x1aR\n\x14TrainingMethodsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12)\n\x05value\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.ListValue:\x02\x38\x01\"\x8d\x02\n\x0eMethodMetadata\x12\x1a\n\x10\x64\x65\x66\x61ult_model_id\x18\xd1\x86\x03 \x01(\t\x12\x1d\n\x13max_models_per_user\x18\xd2\x86\x03 \x01(\x04\x12\x1d\n\x13\x64\x61taset_max_size_mb\x18\xd3\x86\x03 \x01(\x04\x12!\n\x17\x64\x61taset_max_count_files\x18\xd4\x86\x03 \x01(\x04\x12)\n\x1f\x64\x61taset_max_size_single_file_mb\x18\xd5\x86\x03 \x01(\x04\x12\x1c\n\x12\x64\x61taset_files_type\x18\xd6\x86\x03 \x01(\t\x12\x16\n\x0c\x64\x61taset_type\x18\xd7\x86\x03 \x01(\t\x12\x1d\n\x13\x64\x61taset_description\x18\xd8\x86\x03 \x01(\t2\x93\x07\n\x06\x44\x61\x65mon\x12\x44\n\x0c\x63reate_model\x12\x19.training.NewModelRequest\x1a\x17.training.ModelResponse\"\x00\x12R\n\x14validate_model_price\x12\x1d.training.AuthValidateRequest\x1a\x19.training.PriceInBaseUnit\"\x00\x12W\n\x13upload_and_validate\x12\".training.UploadAndValidateRequest\x1a\x18.training.StatusResponse\"\x00(\x01\x12K\n\x0evalidate_model\x12\x1d.training.AuthValidateRequest\x1a\x18.training.StatusResponse\"\x00\x12I\n\x11train_model_price\x12\x17.training.CommonRequest\x1a\x19.training.PriceInBaseUnit\"\x00\x12\x42\n\x0btrain_model\x12\x17.training.CommonRequest\x1a\x18.training.StatusResponse\"\x00\x12\x43\n\x0c\x64\x65lete_model\x12\x17.training.CommonRequest\x1a\x18.training.StatusResponse\"\x00\x12H\n\x0eget_all_models\x12\x1a.training.AllModelsRequest\x1a\x18.training.ModelsResponse\"\x00\x12?\n\tget_model\x12\x17.training.CommonRequest\x1a\x17.training.ModelResponse\"\x00\x12G\n\x0cupdate_model\x12\x1c.training.UpdateModelRequest\x1a\x17.training.ModelResponse\"\x00\x12M\n\x15get_training_metadata\x12\x16.google.protobuf.Empty\x1a\x1a.training.TrainingMetadata\"\x00\x12R\n\x13get_method_metadata\x12\x1f.training.MethodMetadataRequest\x1a\x18.training.MethodMetadata\"\x00\x42,Z*github.com/singnet/snet-daemon/v5/trainingb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'training_daemon_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'Z*github.com/singnet/snet-daemon/v5/training' + _TRAININGMETADATA_TRAININGMETHODSENTRY._options = None + _TRAININGMETADATA_TRAININGMETHODSENTRY._serialized_options = b'8\001' + _globals['_AUTHORIZATIONDETAILS']._serialized_start=144 + _globals['_AUTHORIZATIONDETAILS']._serialized_end=249 + _globals['_NEWMODELREQUEST']._serialized_start=251 + _globals['_NEWMODELREQUEST']._serialized_end=358 + _globals['_AUTHVALIDATEREQUEST']._serialized_start=360 + _globals['_AUTHVALIDATEREQUEST']._serialized_end=482 + _globals['_UPLOADANDVALIDATEREQUEST']._serialized_start=484 + _globals['_UPLOADANDVALIDATEREQUEST']._serialized_end=610 + _globals['_COMMONREQUEST']._serialized_start=612 + _globals['_COMMONREQUEST']._serialized_end=700 + _globals['_UPDATEMODELREQUEST']._serialized_start=703 + _globals['_UPDATEMODELREQUEST']._serialized_end=900 + _globals['_MODELSRESPONSE']._serialized_start=902 + _globals['_MODELSRESPONSE']._serialized_end=967 + _globals['_METHODMETADATAREQUEST']._serialized_start=969 + _globals['_METHODMETADATAREQUEST']._serialized_end=1063 + _globals['_ALLMODELSREQUEST']._serialized_start=1066 + _globals['_ALLMODELSREQUEST']._serialized_end=1341 + _globals['_TRAININGMETADATA']._serialized_start=1344 + _globals['_TRAININGMETADATA']._serialized_end=1570 + _globals['_TRAININGMETADATA_TRAININGMETHODSENTRY']._serialized_start=1488 + _globals['_TRAININGMETADATA_TRAININGMETHODSENTRY']._serialized_end=1570 + _globals['_METHODMETADATA']._serialized_start=1573 + _globals['_METHODMETADATA']._serialized_end=1842 + _globals['_DAEMON']._serialized_start=1845 + _globals['_DAEMON']._serialized_end=2760 +# @@protoc_insertion_point(module_scope) diff --git a/snet/sdk/resources/proto/training_daemon_pb2_grpc.py b/snet/sdk/resources/proto/training_daemon_pb2_grpc.py new file mode 100644 index 0000000..24378b1 --- /dev/null +++ b/snet/sdk/resources/proto/training_daemon_pb2_grpc.py @@ -0,0 +1,445 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from google.protobuf import empty_pb2 as google_dot_protobuf_dot_empty__pb2 +import training_daemon_pb2 as training__daemon__pb2 +import training_pb2 as training__pb2 + + +class DaemonStub(object): + """These methods are implemented only by the daemon; the service provider should ignore them + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.create_model = channel.unary_unary( + '/training.Daemon/create_model', + request_serializer=training__daemon__pb2.NewModelRequest.SerializeToString, + response_deserializer=training__pb2.ModelResponse.FromString, + ) + self.validate_model_price = channel.unary_unary( + '/training.Daemon/validate_model_price', + request_serializer=training__daemon__pb2.AuthValidateRequest.SerializeToString, + response_deserializer=training__pb2.PriceInBaseUnit.FromString, + ) + self.upload_and_validate = channel.stream_unary( + '/training.Daemon/upload_and_validate', + request_serializer=training__daemon__pb2.UploadAndValidateRequest.SerializeToString, + response_deserializer=training__pb2.StatusResponse.FromString, + ) + self.validate_model = channel.unary_unary( + '/training.Daemon/validate_model', + request_serializer=training__daemon__pb2.AuthValidateRequest.SerializeToString, + response_deserializer=training__pb2.StatusResponse.FromString, + ) + self.train_model_price = channel.unary_unary( + '/training.Daemon/train_model_price', + request_serializer=training__daemon__pb2.CommonRequest.SerializeToString, + response_deserializer=training__pb2.PriceInBaseUnit.FromString, + ) + self.train_model = channel.unary_unary( + '/training.Daemon/train_model', + request_serializer=training__daemon__pb2.CommonRequest.SerializeToString, + response_deserializer=training__pb2.StatusResponse.FromString, + ) + self.delete_model = channel.unary_unary( + '/training.Daemon/delete_model', + request_serializer=training__daemon__pb2.CommonRequest.SerializeToString, + response_deserializer=training__pb2.StatusResponse.FromString, + ) + self.get_all_models = channel.unary_unary( + '/training.Daemon/get_all_models', + request_serializer=training__daemon__pb2.AllModelsRequest.SerializeToString, + response_deserializer=training__daemon__pb2.ModelsResponse.FromString, + ) + self.get_model = channel.unary_unary( + '/training.Daemon/get_model', + request_serializer=training__daemon__pb2.CommonRequest.SerializeToString, + response_deserializer=training__pb2.ModelResponse.FromString, + ) + self.update_model = channel.unary_unary( + '/training.Daemon/update_model', + request_serializer=training__daemon__pb2.UpdateModelRequest.SerializeToString, + response_deserializer=training__pb2.ModelResponse.FromString, + ) + self.get_training_metadata = channel.unary_unary( + '/training.Daemon/get_training_metadata', + request_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + response_deserializer=training__daemon__pb2.TrainingMetadata.FromString, + ) + self.get_method_metadata = channel.unary_unary( + '/training.Daemon/get_method_metadata', + request_serializer=training__daemon__pb2.MethodMetadataRequest.SerializeToString, + response_deserializer=training__daemon__pb2.MethodMetadata.FromString, + ) + + +class DaemonServicer(object): + """These methods are implemented only by the daemon; the service provider should ignore them + """ + + def create_model(self, request, context): + """Free + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def validate_model_price(self, request, context): + """Free + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def upload_and_validate(self, request_iterator, context): + """Paid + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def validate_model(self, request, context): + """Paid + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def train_model_price(self, request, context): + """Free, one signature for both train_model_price & train_model methods + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def train_model(self, request, context): + """Paid + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def delete_model(self, request, context): + """Free + After deleting the model, the status becomes DELETED in etcd + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def get_all_models(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def get_model(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def update_model(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def get_training_metadata(self, request, context): + """Unique methods by daemon + One signature for all getters + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def get_method_metadata(self, request, context): + """Free & without authorization + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_DaemonServicer_to_server(servicer, server): + rpc_method_handlers = { + 'create_model': grpc.unary_unary_rpc_method_handler( + servicer.create_model, + request_deserializer=training__daemon__pb2.NewModelRequest.FromString, + response_serializer=training__pb2.ModelResponse.SerializeToString, + ), + 'validate_model_price': grpc.unary_unary_rpc_method_handler( + servicer.validate_model_price, + request_deserializer=training__daemon__pb2.AuthValidateRequest.FromString, + response_serializer=training__pb2.PriceInBaseUnit.SerializeToString, + ), + 'upload_and_validate': grpc.stream_unary_rpc_method_handler( + servicer.upload_and_validate, + request_deserializer=training__daemon__pb2.UploadAndValidateRequest.FromString, + response_serializer=training__pb2.StatusResponse.SerializeToString, + ), + 'validate_model': grpc.unary_unary_rpc_method_handler( + servicer.validate_model, + request_deserializer=training__daemon__pb2.AuthValidateRequest.FromString, + response_serializer=training__pb2.StatusResponse.SerializeToString, + ), + 'train_model_price': grpc.unary_unary_rpc_method_handler( + servicer.train_model_price, + request_deserializer=training__daemon__pb2.CommonRequest.FromString, + response_serializer=training__pb2.PriceInBaseUnit.SerializeToString, + ), + 'train_model': grpc.unary_unary_rpc_method_handler( + servicer.train_model, + request_deserializer=training__daemon__pb2.CommonRequest.FromString, + response_serializer=training__pb2.StatusResponse.SerializeToString, + ), + 'delete_model': grpc.unary_unary_rpc_method_handler( + servicer.delete_model, + request_deserializer=training__daemon__pb2.CommonRequest.FromString, + response_serializer=training__pb2.StatusResponse.SerializeToString, + ), + 'get_all_models': grpc.unary_unary_rpc_method_handler( + servicer.get_all_models, + request_deserializer=training__daemon__pb2.AllModelsRequest.FromString, + response_serializer=training__daemon__pb2.ModelsResponse.SerializeToString, + ), + 'get_model': grpc.unary_unary_rpc_method_handler( + servicer.get_model, + request_deserializer=training__daemon__pb2.CommonRequest.FromString, + response_serializer=training__pb2.ModelResponse.SerializeToString, + ), + 'update_model': grpc.unary_unary_rpc_method_handler( + servicer.update_model, + request_deserializer=training__daemon__pb2.UpdateModelRequest.FromString, + response_serializer=training__pb2.ModelResponse.SerializeToString, + ), + 'get_training_metadata': grpc.unary_unary_rpc_method_handler( + servicer.get_training_metadata, + request_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, + response_serializer=training__daemon__pb2.TrainingMetadata.SerializeToString, + ), + 'get_method_metadata': grpc.unary_unary_rpc_method_handler( + servicer.get_method_metadata, + request_deserializer=training__daemon__pb2.MethodMetadataRequest.FromString, + response_serializer=training__daemon__pb2.MethodMetadata.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'training.Daemon', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Daemon(object): + """These methods are implemented only by the daemon; the service provider should ignore them + """ + + @staticmethod + def create_model(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/create_model', + training__daemon__pb2.NewModelRequest.SerializeToString, + training__pb2.ModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def validate_model_price(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/validate_model_price', + training__daemon__pb2.AuthValidateRequest.SerializeToString, + training__pb2.PriceInBaseUnit.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def upload_and_validate(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary(request_iterator, target, '/training.Daemon/upload_and_validate', + training__daemon__pb2.UploadAndValidateRequest.SerializeToString, + training__pb2.StatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def validate_model(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/validate_model', + training__daemon__pb2.AuthValidateRequest.SerializeToString, + training__pb2.StatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def train_model_price(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/train_model_price', + training__daemon__pb2.CommonRequest.SerializeToString, + training__pb2.PriceInBaseUnit.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def train_model(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/train_model', + training__daemon__pb2.CommonRequest.SerializeToString, + training__pb2.StatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def delete_model(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/delete_model', + training__daemon__pb2.CommonRequest.SerializeToString, + training__pb2.StatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def get_all_models(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/get_all_models', + training__daemon__pb2.AllModelsRequest.SerializeToString, + training__daemon__pb2.ModelsResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def get_model(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/get_model', + training__daemon__pb2.CommonRequest.SerializeToString, + training__pb2.ModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def update_model(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/update_model', + training__daemon__pb2.UpdateModelRequest.SerializeToString, + training__pb2.ModelResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def get_training_metadata(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/get_training_metadata', + google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, + training__daemon__pb2.TrainingMetadata.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def get_method_metadata(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Daemon/get_method_metadata', + training__daemon__pb2.MethodMetadataRequest.SerializeToString, + training__daemon__pb2.MethodMetadata.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/snet/sdk/resources/proto/training_pb2.py b/snet/sdk/resources/proto/training_pb2.py index 735e256..728189e 100644 --- a/snet/sdk/resources/proto/training_pb2.py +++ b/snet/sdk/resources/proto/training_pb2.py @@ -14,34 +14,30 @@ from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a google/protobuf/descriptor.proto\"\xb5\x02\n\x0cModelDetails\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x18\n\x10grpc_method_name\x18\x02 \x01(\t\x12\x19\n\x11grpc_service_name\x18\x03 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x04 \x01(\t\x12\x0e\n\x06status\x18\x06 \x01(\t\x12\x14\n\x0cupdated_date\x18\x07 \x01(\t\x12\x14\n\x0c\x61\x64\x64ress_list\x18\x08 \x03(\t\x12\x1a\n\x12training_data_link\x18\t \x01(\t\x12\x12\n\nmodel_name\x18\n \x01(\t\x12\x17\n\x0forganization_id\x18\x0b \x01(\t\x12\x12\n\nservice_id\x18\x0c \x01(\t\x12\x10\n\x08group_id\x18\r \x01(\t\x12\x1e\n\x16is_publicly_accessible\x18\x0e \x01(\x08\"i\n\x14\x41uthorizationDetails\x12\x15\n\rcurrent_block\x18\x01 \x01(\x04\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x11\n\tsignature\x18\x03 \x01(\x0c\x12\x16\n\x0esigner_address\x18\x04 \x01(\t\"z\n\x12\x43reateModelRequest\x12\x35\n\rauthorization\x18\x01 \x01(\x0b\x32\x1e.training.AuthorizationDetails\x12-\n\rmodel_details\x18\x02 \x01(\x0b\x32\x16.training.ModelDetails\"\x85\x01\n\x17\x41\x63\x63\x65ssibleModelsRequest\x12\x18\n\x10grpc_method_name\x18\x01 \x01(\t\x12\x19\n\x11grpc_service_name\x18\x02 \x01(\t\x12\x35\n\rauthorization\x18\x03 \x01(\x0b\x32\x1e.training.AuthorizationDetails\"J\n\x18\x41\x63\x63\x65ssibleModelsResponse\x12.\n\x0elist_of_models\x18\x01 \x03(\x0b\x32\x16.training.ModelDetails\"{\n\x13ModelDetailsRequest\x12-\n\rmodel_details\x18\x01 \x01(\x0b\x32\x16.training.ModelDetails\x12\x35\n\rauthorization\x18\x02 \x01(\x0b\x32\x1e.training.AuthorizationDetails\"7\n\x14TrainingMethodOption\x12\x1f\n\x17trainingMethodIndicator\x18\x01 \x01(\t\"\x81\x01\n\x12UpdateModelRequest\x12\x34\n\x14update_model_details\x18\x01 \x01(\x0b\x32\x16.training.ModelDetails\x12\x35\n\rauthorization\x18\x02 \x01(\x0b\x32\x1e.training.AuthorizationDetails\"g\n\x14ModelDetailsResponse\x12 \n\x06status\x18\x01 \x01(\x0e\x32\x10.training.Status\x12-\n\rmodel_details\x18\x02 \x01(\x0b\x32\x16.training.ModelDetails*O\n\x06Status\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x0b\n\x07\x45RRORED\x10\x02\x12\r\n\tCOMPLETED\x10\x03\x12\x0b\n\x07\x44\x45LETED\x10\x04\x32\xae\x03\n\x05Model\x12N\n\x0c\x63reate_model\x12\x1c.training.CreateModelRequest\x1a\x1e.training.ModelDetailsResponse\"\x00\x12N\n\x0c\x64\x65lete_model\x12\x1c.training.UpdateModelRequest\x1a\x1e.training.ModelDetailsResponse\"\x00\x12S\n\x10get_model_status\x12\x1d.training.ModelDetailsRequest\x1a\x1e.training.ModelDetailsResponse\"\x00\x12U\n\x13update_model_access\x12\x1c.training.UpdateModelRequest\x1a\x1e.training.ModelDetailsResponse\"\x00\x12Y\n\x0eget_all_models\x12!.training.AccessibleModelsRequest\x1a\".training.AccessibleModelsResponse\"\x00:[\n\x10my_method_option\x12\x1e.google.protobuf.MethodOptions\x18\xdd\xa6\xe2\x04 \x01(\x0b\x32\x1e.training.TrainingMethodOptionB\rZ\x0b../trainingb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0etraining.proto\x12\x08training\x1a google/protobuf/descriptor.proto\"\xc4\x02\n\rModelResponse\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12 \n\x06status\x18\x02 \x01(\x0e\x32\x10.training.Status\x12\x14\n\x0c\x63reated_date\x18\x03 \x01(\t\x12\x14\n\x0cupdated_date\x18\x04 \x01(\t\x12\x0c\n\x04name\x18\x05 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x06 \x01(\t\x12\x18\n\x10grpc_method_name\x18\x07 \x01(\t\x12\x19\n\x11grpc_service_name\x18\x08 \x01(\t\x12\x14\n\x0c\x61\x64\x64ress_list\x18\t \x03(\t\x12\x11\n\tis_public\x18\n \x01(\x08\x12\x1a\n\x12training_data_link\x18\x0b \x01(\t\x12\x1a\n\x12\x63reated_by_address\x18\x0c \x01(\t\x12\x1a\n\x12updated_by_address\x18\r \x01(\t\"\xca\x01\n\x08NewModel\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x18\n\x10grpc_method_name\x18\x03 \x01(\t\x12\x19\n\x11grpc_service_name\x18\x04 \x01(\t\x12\x14\n\x0c\x61\x64\x64ress_list\x18\x05 \x03(\t\x12\x11\n\tis_public\x18\x06 \x01(\x08\x12\x17\n\x0forganization_id\x18\x07 \x01(\t\x12\x12\n\nservice_id\x18\x08 \x01(\t\x12\x10\n\x08group_id\x18\t \x01(\t\"\x1b\n\x07ModelID\x12\x10\n\x08model_id\x18\x01 \x01(\t\" \n\x0fPriceInBaseUnit\x12\r\n\x05price\x18\x01 \x01(\x04\"2\n\x0eStatusResponse\x12 \n\x06status\x18\x01 \x01(\x0e\x32\x10.training.Status\"\x92\x01\n\x0bUploadInput\x12\x10\n\x08model_id\x18\x01 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\x12\x11\n\tfile_name\x18\x03 \x01(\t\x12\x11\n\tfile_size\x18\x04 \x01(\x04\x12\x12\n\nbatch_size\x18\x05 \x01(\x04\x12\x14\n\x0c\x62\x61tch_number\x18\x06 \x01(\x04\x12\x13\n\x0b\x62\x61tch_count\x18\x07 \x01(\x04\"?\n\x0fValidateRequest\x12\x10\n\x08model_id\x18\x02 \x01(\t\x12\x1a\n\x12training_data_link\x18\x03 \x01(\t*n\n\x06Status\x12\x0b\n\x07\x43REATED\x10\x00\x12\x0e\n\nVALIDATING\x10\x01\x12\r\n\tVALIDATED\x10\x02\x12\x0c\n\x08TRAINING\x10\x03\x12\x10\n\x0cREADY_TO_USE\x10\x04\x12\x0b\n\x07\x45RRORED\x10\x05\x12\x0b\n\x07\x44\x45LETED\x10\x06\x32\xaa\x04\n\x05Model\x12\x37\n\x0c\x63reate_model\x12\x12.training.NewModel\x1a\x11.training.ModelID\"\x00\x12N\n\x14validate_model_price\x12\x19.training.ValidateRequest\x1a\x19.training.PriceInBaseUnit\"\x00\x12J\n\x13upload_and_validate\x12\x15.training.UploadInput\x1a\x18.training.StatusResponse\"\x00(\x01\x12G\n\x0evalidate_model\x12\x19.training.ValidateRequest\x1a\x18.training.StatusResponse\"\x00\x12\x43\n\x11train_model_price\x12\x11.training.ModelID\x1a\x19.training.PriceInBaseUnit\"\x00\x12<\n\x0btrain_model\x12\x11.training.ModelID\x1a\x18.training.StatusResponse\"\x00\x12=\n\x0c\x64\x65lete_model\x12\x11.training.ModelID\x1a\x18.training.StatusResponse\"\x00\x12\x41\n\x10get_model_status\x12\x11.training.ModelID\x1a\x18.training.StatusResponse\"\x00::\n\x10\x64\x65\x66\x61ult_model_id\x12\x1e.google.protobuf.MethodOptions\x18\xd1\x86\x03 \x01(\t:=\n\x13max_models_per_user\x12\x1e.google.protobuf.MethodOptions\x18\xd2\x86\x03 \x01(\x04:=\n\x13\x64\x61taset_max_size_mb\x12\x1e.google.protobuf.MethodOptions\x18\xd3\x86\x03 \x01(\x04:A\n\x17\x64\x61taset_max_count_files\x12\x1e.google.protobuf.MethodOptions\x18\xd4\x86\x03 \x01(\x04:I\n\x1f\x64\x61taset_max_size_single_file_mb\x12\x1e.google.protobuf.MethodOptions\x18\xd5\x86\x03 \x01(\x04:<\n\x12\x64\x61taset_files_type\x12\x1e.google.protobuf.MethodOptions\x18\xd6\x86\x03 \x01(\t:6\n\x0c\x64\x61taset_type\x12\x1e.google.protobuf.MethodOptions\x18\xd7\x86\x03 \x01(\t:=\n\x13\x64\x61taset_description\x12\x1e.google.protobuf.MethodOptions\x18\xd8\x86\x03 \x01(\tB,Z*github.com/singnet/snet-daemon/v5/trainingb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'training_pb2', _globals) if _descriptor._USE_C_DESCRIPTORS == False: DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'Z\013../training' - _globals['_STATUS']._serialized_start=1236 - _globals['_STATUS']._serialized_end=1315 - _globals['_MODELDETAILS']._serialized_start=63 - _globals['_MODELDETAILS']._serialized_end=372 - _globals['_AUTHORIZATIONDETAILS']._serialized_start=374 - _globals['_AUTHORIZATIONDETAILS']._serialized_end=479 - _globals['_CREATEMODELREQUEST']._serialized_start=481 - _globals['_CREATEMODELREQUEST']._serialized_end=603 - _globals['_ACCESSIBLEMODELSREQUEST']._serialized_start=606 - _globals['_ACCESSIBLEMODELSREQUEST']._serialized_end=739 - _globals['_ACCESSIBLEMODELSRESPONSE']._serialized_start=741 - _globals['_ACCESSIBLEMODELSRESPONSE']._serialized_end=815 - _globals['_MODELDETAILSREQUEST']._serialized_start=817 - _globals['_MODELDETAILSREQUEST']._serialized_end=940 - _globals['_TRAININGMETHODOPTION']._serialized_start=942 - _globals['_TRAININGMETHODOPTION']._serialized_end=997 - _globals['_UPDATEMODELREQUEST']._serialized_start=1000 - _globals['_UPDATEMODELREQUEST']._serialized_end=1129 - _globals['_MODELDETAILSRESPONSE']._serialized_start=1131 - _globals['_MODELDETAILSRESPONSE']._serialized_end=1234 - _globals['_MODEL']._serialized_start=1318 - _globals['_MODEL']._serialized_end=1748 + DESCRIPTOR._serialized_options = b'Z*github.com/singnet/snet-daemon/v5/training' + _globals['_STATUS']._serialized_start=923 + _globals['_STATUS']._serialized_end=1033 + _globals['_MODELRESPONSE']._serialized_start=63 + _globals['_MODELRESPONSE']._serialized_end=387 + _globals['_NEWMODEL']._serialized_start=390 + _globals['_NEWMODEL']._serialized_end=592 + _globals['_MODELID']._serialized_start=594 + _globals['_MODELID']._serialized_end=621 + _globals['_PRICEINBASEUNIT']._serialized_start=623 + _globals['_PRICEINBASEUNIT']._serialized_end=655 + _globals['_STATUSRESPONSE']._serialized_start=657 + _globals['_STATUSRESPONSE']._serialized_end=707 + _globals['_UPLOADINPUT']._serialized_start=710 + _globals['_UPLOADINPUT']._serialized_end=856 + _globals['_VALIDATEREQUEST']._serialized_start=858 + _globals['_VALIDATEREQUEST']._serialized_end=921 + _globals['_MODEL']._serialized_start=1036 + _globals['_MODEL']._serialized_end=1590 # @@protoc_insertion_point(module_scope) diff --git a/snet/sdk/resources/proto/training_pb2_grpc.py b/snet/sdk/resources/proto/training_pb2_grpc.py index 51ccf27..480dfcf 100644 --- a/snet/sdk/resources/proto/training_pb2_grpc.py +++ b/snet/sdk/resources/proto/training_pb2_grpc.py @@ -6,7 +6,8 @@ class ModelStub(object): - """Missing associated documentation comment in .proto file.""" + """Methods that the service provider must implement + """ def __init__(self, channel): """Constructor. @@ -16,65 +17,104 @@ def __init__(self, channel): """ self.create_model = channel.unary_unary( '/training.Model/create_model', - request_serializer=training__pb2.CreateModelRequest.SerializeToString, - response_deserializer=training__pb2.ModelDetailsResponse.FromString, + request_serializer=training__pb2.NewModel.SerializeToString, + response_deserializer=training__pb2.ModelID.FromString, + ) + self.validate_model_price = channel.unary_unary( + '/training.Model/validate_model_price', + request_serializer=training__pb2.ValidateRequest.SerializeToString, + response_deserializer=training__pb2.PriceInBaseUnit.FromString, + ) + self.upload_and_validate = channel.stream_unary( + '/training.Model/upload_and_validate', + request_serializer=training__pb2.UploadInput.SerializeToString, + response_deserializer=training__pb2.StatusResponse.FromString, + ) + self.validate_model = channel.unary_unary( + '/training.Model/validate_model', + request_serializer=training__pb2.ValidateRequest.SerializeToString, + response_deserializer=training__pb2.StatusResponse.FromString, + ) + self.train_model_price = channel.unary_unary( + '/training.Model/train_model_price', + request_serializer=training__pb2.ModelID.SerializeToString, + response_deserializer=training__pb2.PriceInBaseUnit.FromString, + ) + self.train_model = channel.unary_unary( + '/training.Model/train_model', + request_serializer=training__pb2.ModelID.SerializeToString, + response_deserializer=training__pb2.StatusResponse.FromString, ) self.delete_model = channel.unary_unary( '/training.Model/delete_model', - request_serializer=training__pb2.UpdateModelRequest.SerializeToString, - response_deserializer=training__pb2.ModelDetailsResponse.FromString, + request_serializer=training__pb2.ModelID.SerializeToString, + response_deserializer=training__pb2.StatusResponse.FromString, ) self.get_model_status = channel.unary_unary( '/training.Model/get_model_status', - request_serializer=training__pb2.ModelDetailsRequest.SerializeToString, - response_deserializer=training__pb2.ModelDetailsResponse.FromString, - ) - self.update_model_access = channel.unary_unary( - '/training.Model/update_model_access', - request_serializer=training__pb2.UpdateModelRequest.SerializeToString, - response_deserializer=training__pb2.ModelDetailsResponse.FromString, - ) - self.get_all_models = channel.unary_unary( - '/training.Model/get_all_models', - request_serializer=training__pb2.AccessibleModelsRequest.SerializeToString, - response_deserializer=training__pb2.AccessibleModelsResponse.FromString, + request_serializer=training__pb2.ModelID.SerializeToString, + response_deserializer=training__pb2.StatusResponse.FromString, ) class ModelServicer(object): - """Missing associated documentation comment in .proto file.""" + """Methods that the service provider must implement + """ def create_model(self, request, context): - """The AI developer needs to Implement this service and Daemon will call these - There will be no cost borne by the consumer in calling these methods, - Pricing will apply when you actually call the training methods defined. - AI consumer will call all these methods + """Free + Can pass the address of the model creator """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def delete_model(self, request, context): - """Missing associated documentation comment in .proto file.""" + def validate_model_price(self, request, context): + """Free + """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def get_model_status(self, request, context): - """Missing associated documentation comment in .proto file.""" + def upload_and_validate(self, request_iterator, context): + """Paid + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def validate_model(self, request, context): + """Paid + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def train_model_price(self, request, context): + """Free, one signature for both train_model_price & train_model methods + """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def update_model_access(self, request, context): - """Daemon will implement , however the AI developer should skip implementing these and just provide dummy code. + def train_model(self, request, context): + """Paid """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def get_all_models(self, request, context): - """Missing associated documentation comment in .proto file.""" + def delete_model(self, request, context): + """Free + After model deletion, the status becomes DELETED in etcd + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def get_model_status(self, request, context): + """Free + """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') @@ -84,28 +124,43 @@ def add_ModelServicer_to_server(servicer, server): rpc_method_handlers = { 'create_model': grpc.unary_unary_rpc_method_handler( servicer.create_model, - request_deserializer=training__pb2.CreateModelRequest.FromString, - response_serializer=training__pb2.ModelDetailsResponse.SerializeToString, + request_deserializer=training__pb2.NewModel.FromString, + response_serializer=training__pb2.ModelID.SerializeToString, + ), + 'validate_model_price': grpc.unary_unary_rpc_method_handler( + servicer.validate_model_price, + request_deserializer=training__pb2.ValidateRequest.FromString, + response_serializer=training__pb2.PriceInBaseUnit.SerializeToString, + ), + 'upload_and_validate': grpc.stream_unary_rpc_method_handler( + servicer.upload_and_validate, + request_deserializer=training__pb2.UploadInput.FromString, + response_serializer=training__pb2.StatusResponse.SerializeToString, + ), + 'validate_model': grpc.unary_unary_rpc_method_handler( + servicer.validate_model, + request_deserializer=training__pb2.ValidateRequest.FromString, + response_serializer=training__pb2.StatusResponse.SerializeToString, + ), + 'train_model_price': grpc.unary_unary_rpc_method_handler( + servicer.train_model_price, + request_deserializer=training__pb2.ModelID.FromString, + response_serializer=training__pb2.PriceInBaseUnit.SerializeToString, + ), + 'train_model': grpc.unary_unary_rpc_method_handler( + servicer.train_model, + request_deserializer=training__pb2.ModelID.FromString, + response_serializer=training__pb2.StatusResponse.SerializeToString, ), 'delete_model': grpc.unary_unary_rpc_method_handler( servicer.delete_model, - request_deserializer=training__pb2.UpdateModelRequest.FromString, - response_serializer=training__pb2.ModelDetailsResponse.SerializeToString, + request_deserializer=training__pb2.ModelID.FromString, + response_serializer=training__pb2.StatusResponse.SerializeToString, ), 'get_model_status': grpc.unary_unary_rpc_method_handler( servicer.get_model_status, - request_deserializer=training__pb2.ModelDetailsRequest.FromString, - response_serializer=training__pb2.ModelDetailsResponse.SerializeToString, - ), - 'update_model_access': grpc.unary_unary_rpc_method_handler( - servicer.update_model_access, - request_deserializer=training__pb2.UpdateModelRequest.FromString, - response_serializer=training__pb2.ModelDetailsResponse.SerializeToString, - ), - 'get_all_models': grpc.unary_unary_rpc_method_handler( - servicer.get_all_models, - request_deserializer=training__pb2.AccessibleModelsRequest.FromString, - response_serializer=training__pb2.AccessibleModelsResponse.SerializeToString, + request_deserializer=training__pb2.ModelID.FromString, + response_serializer=training__pb2.StatusResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -115,7 +170,8 @@ def add_ModelServicer_to_server(servicer, server): # This class is part of an EXPERIMENTAL API. class Model(object): - """Missing associated documentation comment in .proto file.""" + """Methods that the service provider must implement + """ @staticmethod def create_model(request, @@ -129,13 +185,13 @@ def create_model(request, timeout=None, metadata=None): return grpc.experimental.unary_unary(request, target, '/training.Model/create_model', - training__pb2.CreateModelRequest.SerializeToString, - training__pb2.ModelDetailsResponse.FromString, + training__pb2.NewModel.SerializeToString, + training__pb2.ModelID.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def delete_model(request, + def validate_model_price(request, target, options=(), channel_credentials=None, @@ -145,14 +201,14 @@ def delete_model(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/training.Model/delete_model', - training__pb2.UpdateModelRequest.SerializeToString, - training__pb2.ModelDetailsResponse.FromString, + return grpc.experimental.unary_unary(request, target, '/training.Model/validate_model_price', + training__pb2.ValidateRequest.SerializeToString, + training__pb2.PriceInBaseUnit.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def get_model_status(request, + def upload_and_validate(request_iterator, target, options=(), channel_credentials=None, @@ -162,14 +218,14 @@ def get_model_status(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/training.Model/get_model_status', - training__pb2.ModelDetailsRequest.SerializeToString, - training__pb2.ModelDetailsResponse.FromString, + return grpc.experimental.stream_unary(request_iterator, target, '/training.Model/upload_and_validate', + training__pb2.UploadInput.SerializeToString, + training__pb2.StatusResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def update_model_access(request, + def validate_model(request, target, options=(), channel_credentials=None, @@ -179,14 +235,14 @@ def update_model_access(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/training.Model/update_model_access', - training__pb2.UpdateModelRequest.SerializeToString, - training__pb2.ModelDetailsResponse.FromString, + return grpc.experimental.unary_unary(request, target, '/training.Model/validate_model', + training__pb2.ValidateRequest.SerializeToString, + training__pb2.StatusResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) @staticmethod - def get_all_models(request, + def train_model_price(request, target, options=(), channel_credentials=None, @@ -196,8 +252,59 @@ def get_all_models(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary(request, target, '/training.Model/get_all_models', - training__pb2.AccessibleModelsRequest.SerializeToString, - training__pb2.AccessibleModelsResponse.FromString, + return grpc.experimental.unary_unary(request, target, '/training.Model/train_model_price', + training__pb2.ModelID.SerializeToString, + training__pb2.PriceInBaseUnit.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def train_model(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Model/train_model', + training__pb2.ModelID.SerializeToString, + training__pb2.StatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def delete_model(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Model/delete_model', + training__pb2.ModelID.SerializeToString, + training__pb2.StatusResponse.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def get_model_status(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/training.Model/get_model_status', + training__pb2.ModelID.SerializeToString, + training__pb2.StatusResponse.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/snet/sdk/resources/proto/unixfs.proto b/snet/sdk/resources/proto/unixfs.proto deleted file mode 100644 index c190079..0000000 --- a/snet/sdk/resources/proto/unixfs.proto +++ /dev/null @@ -1,25 +0,0 @@ -syntax = "proto2"; -package unixfs.pb; - -message Data { - enum DataType { - Raw = 0; - Directory = 1; - File = 2; - Metadata = 3; - Symlink = 4; - HAMTShard = 5; - } - - required DataType Type = 1; - optional bytes Data = 2; - optional uint64 filesize = 3; - repeated uint64 blocksizes = 4; - - optional uint64 hashType = 5; - optional uint64 fanout = 6; -} - -message Metadata { - optional string MimeType = 1; -} \ No newline at end of file diff --git a/snet/sdk/resources/proto/unixfs_pb2.py b/snet/sdk/resources/proto/unixfs_pb2.py deleted file mode 100644 index 394f454..0000000 --- a/snet/sdk/resources/proto/unixfs_pb2.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: unixfs.proto -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cunixfs.proto\x12\tunixfs.pb\"\xdc\x01\n\x04\x44\x61ta\x12&\n\x04Type\x18\x01 \x02(\x0e\x32\x18.unixfs.pb.Data.DataType\x12\x0c\n\x04\x44\x61ta\x18\x02 \x01(\x0c\x12\x10\n\x08\x66ilesize\x18\x03 \x01(\x04\x12\x12\n\nblocksizes\x18\x04 \x03(\x04\x12\x10\n\x08hashType\x18\x05 \x01(\x04\x12\x0e\n\x06\x66\x61nout\x18\x06 \x01(\x04\"V\n\x08\x44\x61taType\x12\x07\n\x03Raw\x10\x00\x12\r\n\tDirectory\x10\x01\x12\x08\n\x04\x46ile\x10\x02\x12\x0c\n\x08Metadata\x10\x03\x12\x0b\n\x07Symlink\x10\x04\x12\r\n\tHAMTShard\x10\x05\"\x1c\n\x08Metadata\x12\x10\n\x08MimeType\x18\x01 \x01(\t') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'unixfs_pb2', _globals) -if _descriptor._USE_C_DESCRIPTORS == False: - DESCRIPTOR._options = None - _globals['_DATA']._serialized_start=28 - _globals['_DATA']._serialized_end=248 - _globals['_DATA_DATATYPE']._serialized_start=162 - _globals['_DATA_DATATYPE']._serialized_end=248 - _globals['_METADATA']._serialized_start=250 - _globals['_METADATA']._serialized_end=278 -# @@protoc_insertion_point(module_scope) diff --git a/snet/sdk/resources/proto/unixfs_pb2_grpc.py b/snet/sdk/resources/proto/unixfs_pb2_grpc.py deleted file mode 100644 index 2daafff..0000000 --- a/snet/sdk/resources/proto/unixfs_pb2_grpc.py +++ /dev/null @@ -1,4 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc - diff --git a/snet/sdk/service_client.py b/snet/sdk/service_client.py index 1423266..9294e59 100644 --- a/snet/sdk/service_client.py +++ b/snet/sdk/service_client.py @@ -1,72 +1,112 @@ import base64 -import collections import importlib import re import os from pathlib import Path +from typing import Any +from eth_typing import BlockNumber, ChecksumAddress import grpc +from hexbytes import HexBytes import web3 from eth_account.messages import defunct_hash_message from rfc3986 import urlparse -from snet.sdk.resources.root_certificate import certificate -from snet.sdk.utils.utils import RESOURCES_PATH, add_to_path, find_file_by_keyword - -import snet.sdk.generic_client_interceptor as generic_client_interceptor - -class _ClientCallDetails( - collections.namedtuple( - '_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials')), - grpc.ClientCallDetails): - pass +from snet.sdk import generic_client_interceptor +from snet.sdk.account import Account +from snet.sdk.mpe.mpe_contract import MPEContract +from snet.sdk.mpe.payment_channel import PaymentChannel +from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider +from snet.sdk.payment_strategies import default_payment_strategy as strategy +from snet.sdk.resources.root_certificate import certificate +from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata +from snet.sdk.custom_typing import ModuleName, ServiceStub +from snet.sdk.utils.utils import (RESOURCES_PATH, add_to_path, + find_file_by_keyword) +from snet.sdk.training.training import Training +from snet.sdk.training.exceptions import NoTrainingException +from snet.sdk.utils.call_utils import create_intercept_call_func class ServiceClient: - def __init__(self, org_id, service_id, service_metadata, group, service_stub, payment_strategy, - options, mpe_contract, account, sdk_web3, pb2_module, payment_channel_provider): + def __init__( + self, + org_id: str, + service_id: str, + service_metadata: MPEServiceMetadata, + group: dict, + service_stubs: list[ServiceStub], + payment_strategy, + options: dict, + mpe_contract: MPEContract, + account: Account, + sdk_web3: web3.Web3, + pb2_module: ModuleName, + payment_channel_provider: PaymentChannelProvider, + path_to_pb_files: Path, + training_added: bool = False + ): self.org_id = org_id self.service_id = service_id - self.options = options - self.group = group self.service_metadata = service_metadata - + self.group = group self.payment_strategy = payment_strategy - self.expiry_threshold = self.group["payment"]["payment_expiration_threshold"] - self.__base_grpc_channel = self._get_grpc_channel() - self.grpc_channel = grpc.intercept_channel(self.__base_grpc_channel, - generic_client_interceptor.create(self._intercept_call)) - self.payment_channel_provider = payment_channel_provider - self.payment_channel_state_service_client = self._generate_payment_channel_state_service_client() - self.service = self._generate_grpc_stub(service_stub) - self.pb2_module = importlib.import_module(pb2_module) if isinstance(pb2_module, str) else pb2_module - self.payment_channels = [] - self.last_read_block = 0 + self.options = options + self.mpe_address = mpe_contract.contract.address self.account = account self.sdk_web3 = sdk_web3 - self.mpe_address = mpe_contract.contract.address + self.pb2_module = (importlib.import_module(pb2_module) + if isinstance(pb2_module, str) + else pb2_module) + self.payment_channel_provider = payment_channel_provider + self.path_to_pb_files = path_to_pb_files - def call_rpc(self, rpc_name: str, message_class: str, **kwargs): - rpc_method = getattr(self.service, rpc_name) + self.expiry_threshold: int = self.group["payment"]["payment_expiration_threshold"] + self.__base_grpc_channel = self._get_grpc_channel() + _intercept_call_func = create_intercept_call_func(self.payment_strategy.get_payment_metadata, self) + self.grpc_channel = grpc.intercept_channel( + self.__base_grpc_channel, + generic_client_interceptor.create(_intercept_call_func) + ) + self.service_stubs = service_stubs + self.payment_channel_state_service_client = self._generate_payment_channel_state_service_client() + self.payment_channels = [] + self.last_read_block: int = 0 + self.__training = Training(self, training_added) + + def call_rpc(self, rpc_name: str, message_class: str, **kwargs) -> Any: + service = self._get_service_stub(rpc_name) + if "model_id" in kwargs: + kwargs["model_id"] = self._get_training_model_id(kwargs["model_id"]) + rpc_method = getattr(service, rpc_name) request = getattr(self.pb2_module, message_class)(**kwargs) return rpc_method(request) def _get_payment_expiration_threshold_for_group(self): pass - def _generate_grpc_stub(self, service_stub): + def _get_service_stub(self, rpc_name: str) -> Any: + for service_stub in self.service_stubs: + grpc_stub = self._generate_grpc_stub(service_stub) + if hasattr(grpc_stub, rpc_name): + return grpc_stub + raise Exception(f"Service stub for {rpc_name} not found") + + def _generate_grpc_stub(self, service_stub: ServiceStub) -> Any: grpc_channel = self.__base_grpc_channel - disable_blockchain_operations = self.options.get("disable_blockchain_operations", False) - if disable_blockchain_operations is False: + disable_blockchain_operations: bool = self.options.get( + "disable_blockchain_operations", + False + ) + if not disable_blockchain_operations: grpc_channel = self.grpc_channel stub_instance = service_stub(grpc_channel) return stub_instance - def get_grpc_base_channel(self): + def get_grpc_base_channel(self) -> grpc.Channel: return self.__base_grpc_channel - def _get_grpc_channel(self): + def _get_grpc_channel(self) -> grpc.Channel: endpoint = self.options.get("endpoint", None) if endpoint is None: endpoint = self.service_metadata.get_all_endpoints_for_group(self.group["group_name"])[0] @@ -84,22 +124,10 @@ def _get_grpc_channel(self): else: raise ValueError('Unsupported scheme in service metadata ("{}")'.format(endpoint_object.scheme)) - def _get_service_call_metadata(self): - metadata = self.payment_strategy.get_payment_metadata(self) - return metadata - - def _intercept_call(self, client_call_details, request_iterator, request_streaming, - response_streaming): - metadata = [] - if client_call_details.metadata is not None: - metadata = list(client_call_details.metadata) - metadata.extend(self._get_service_call_metadata()) - client_call_details = _ClientCallDetails( - client_call_details.method, client_call_details.timeout, metadata, - client_call_details.credentials) - return client_call_details, request_iterator, None - - def _filter_existing_channels_from_new_payment_channels(self, new_payment_channels): + def _filter_existing_channels_from_new_payment_channels( + self, + new_payment_channels: list[PaymentChannel] + ) -> list[PaymentChannel]: new_channels_to_be_added = [] # need to change this logic ,use maps to manage channels so that we can easily navigate it @@ -115,99 +143,121 @@ def _filter_existing_channels_from_new_payment_channels(self, new_payment_channe return new_channels_to_be_added - def load_open_channels(self): + def load_open_channels(self) -> list[PaymentChannel]: current_block_number = self.sdk_web3.eth.block_number payment_address = self.group["payment"]["payment_address"] group_id = base64.b64decode(str(self.group["group_id"])) - new_payment_channels = self.payment_channel_provider.get_past_open_channels(self.account, payment_address, - group_id, - self.payment_channel_state_service_client) - self.payment_channels = self.payment_channels + \ - self._filter_existing_channels_from_new_payment_channels(new_payment_channels) + new_payment_channels = ( + self.payment_channel_provider.get_past_open_channels( + self.account, payment_address, + group_id, self.payment_channel_state_service_client + ) + ) + filter_new_channels = self._filter_existing_channels_from_new_payment_channels(new_payment_channels) + self.payment_channels = (self.payment_channels + filter_new_channels) self.last_read_block = current_block_number return self.payment_channels - def get_current_block_number(self): + def get_current_block_number(self) -> BlockNumber: return self.sdk_web3.eth.block_number - def update_channel_states(self): + def update_channel_states(self) -> list[PaymentChannel]: for channel in self.payment_channels: channel.sync_state() return self.payment_channels - def default_channel_expiration(self): + def default_channel_expiration(self) -> int: current_block_number = self.sdk_web3.eth.get_block("latest").number return current_block_number + self.expiry_threshold - def _generate_payment_channel_state_service_client(self): + def _generate_payment_channel_state_service_client(self) -> Any: grpc_channel = self.__base_grpc_channel with add_to_path(str(RESOURCES_PATH.joinpath("proto"))): - state_service_pb2_grpc = importlib.import_module("state_service_pb2_grpc") - return state_service_pb2_grpc.PaymentChannelStateServiceStub(grpc_channel) + state_service = importlib.import_module("state_service_pb2_grpc") + return state_service.PaymentChannelStateServiceStub(grpc_channel) - def open_channel(self, amount, expiration): + def open_channel(self, amount: int, expiration: int) -> PaymentChannel: payment_address = self.group["payment"]["payment_address"] group_id = base64.b64decode(str(self.group["group_id"])) - return self.payment_channel_provider.open_channel(self.account, amount, expiration, payment_address, - group_id, self.payment_channel_state_service_client) + return self.payment_channel_provider.open_channel( + self.account, amount, expiration, payment_address, + group_id, self.payment_channel_state_service_client + ) - def deposit_and_open_channel(self, amount, expiration): + def deposit_and_open_channel(self, amount: int, + expiration: int) -> PaymentChannel: payment_address = self.group["payment"]["payment_address"] group_id = base64.b64decode(str(self.group["group_id"])) - return self.payment_channel_provider.deposit_and_open_channel(self.account, amount, expiration, - payment_address, group_id, - self.payment_channel_state_service_client) + return self.payment_channel_provider.deposit_and_open_channel( + self.account, amount, expiration, payment_address, + group_id, self.payment_channel_state_service_client + ) - def get_price(self): + def get_price(self) -> int: return self.group["pricing"][0]["price_in_cogs"] - def generate_signature(self, message): - signature = bytes(self.sdk_web3.eth.account.signHash(defunct_hash_message(message), - self.account.signer_private_key).signature) - - return signature + def generate_signature(self, message: bytes) -> bytes: + return bytes(self.sdk_web3.eth.account.signHash( + defunct_hash_message(message), self.account.signer_private_key + ).signature) - def generate_training_signature(self, text: str, address, block_number): + def generate_training_signature(self, text: str, address: str, + block_number: BlockNumber) -> HexBytes: + address = web3.Web3.to_checksum_address(address) message = web3.Web3.solidity_keccak( ["string", "address", "uint256"], [text, address, block_number] ) - return self.sdk_web3.eth.account.signHash(defunct_hash_message(message), - self.account.signer_private_key).signature - - def get_free_call_config(self): - return self.options['email'], self.options['free_call_auth_token-bin'], self.options[ - 'free-call-token-expiry-block'] - - def get_service_details(self): - return self.org_id, self.service_id, self.group["group_id"], \ - self.service_metadata.get_all_endpoints_for_group(self.group["group_name"])[0] - - def get_concurrency_flag(self): + return self.sdk_web3.eth.account.signHash( + defunct_hash_message(message), self.account.signer_private_key + ).signature + + def get_free_call_config(self) -> tuple[str, str, int]: + return (self.options['email'], + self.options['free_call_auth_token-bin'], + self.options['free-call-token-expiry-block']) + + def get_service_details(self) -> tuple[str, str, str, str]: + return (self.org_id, + self.service_id, + self.group["group_id"], + self.service_metadata.get_all_endpoints_for_group( + self.group["group_name"] + )[0]) + + @property + def training(self) -> Training: + if not self.__training.is_enabled: + raise NoTrainingException(self.org_id, self.service_id) + return self.__training + + def _get_training_model_id(self, model_id: str) -> Any: + return self.training.get_model_id_object(model_id) + + def get_concurrency_flag(self) -> bool: return self.options.get('concurrency', True) - def get_concurrency_token_and_channel(self): + def get_concurrency_token_and_channel(self) -> tuple[str, PaymentChannel]: return self.payment_strategy.get_concurrency_token_and_channel(self) - def set_concurrency_token_and_channel(self, token, channel): + def set_concurrency_token_and_channel(self, token: str, + channel: PaymentChannel) -> None: self.payment_strategy.concurrency_token = token self.payment_strategy.channel = channel - def get_path_to_pb_files(self, org_id: str, service_id: str) -> str: - client_libraries_base_dir_path = Path("~").expanduser().joinpath(".snet") - path_to_pb_files = f"{client_libraries_base_dir_path}/{org_id}/{service_id}/python/" - return path_to_pb_files - - def get_services_and_messages_info(self): + def get_services_and_messages_info(self) -> tuple[dict, dict]: # Get proto file filepath and open it - path_to_pb_files = self.get_path_to_pb_files(self.org_id, self.service_id) - proto_file_name = find_file_by_keyword(path_to_pb_files, keyword=".proto") - proto_filepath = os.path.join(path_to_pb_files, proto_file_name) + proto_file_name = find_file_by_keyword(directory=self.path_to_pb_files, + keyword=".proto", exclude=["training"]) + proto_filepath = os.path.join(self.path_to_pb_files, proto_file_name) with open(proto_filepath, 'r') as file: proto_content = file.read() - # Regular expression patterns to match services, methods, messages, and fields + # Regular expression patterns to match services, methods, + # messages and fields service_pattern = re.compile(r'service\s+(\w+)\s*{') - rpc_pattern = re.compile(r'rpc\s+(\w+)\s*\((\w+)\)\s+returns\s+\((\w+)\)') + rpc_pattern = re.compile( + r'rpc\s+(\w+)\s*\((\w+)\)\s+returns\s+\((\w+)\)' + ) message_pattern = re.compile(r'message\s+(\w+)\s*{') field_pattern = re.compile(r'\s*(\w+)\s+(\w+)\s*=\s*\d+\s*;') @@ -231,7 +281,8 @@ def get_services_and_messages_info(self): method_name = rpc_match.group(1) input_type = rpc_match.group(2) output_type = rpc_match.group(3) - services[current_service].append((method_name, input_type, output_type)) + services[current_service].append((method_name, input_type, + output_type)) # Match a message definition message_match = message_pattern.search(line) @@ -250,7 +301,7 @@ def get_services_and_messages_info(self): return services, messages - def get_services_and_messages_info_as_pretty_string(self): + def get_services_and_messages_info_as_pretty_string(self) -> str: services, messages = self.get_services_and_messages_info() string_output = "" @@ -258,7 +309,9 @@ def get_services_and_messages_info_as_pretty_string(self): for service, methods in services.items(): string_output += f"Service: {service}\n" for method_name, input_type, output_type in methods: - string_output += f" Method: {method_name}, Input: {input_type}, Output: {output_type}\n" + string_output += (f" Method: {method_name}," + f" Input: {input_type}," + f" Output: {output_type}\n") # Prettify the messages and their fields for message, fields in messages.items(): diff --git a/snet/sdk/storage_provider/service_metadata.py b/snet/sdk/storage_provider/service_metadata.py index e39ca46..a49c573 100644 --- a/snet/sdk/storage_provider/service_metadata.py +++ b/snet/sdk/storage_provider/service_metadata.py @@ -536,13 +536,13 @@ def add_description(self): } -def load_mpe_service_metadata(f): +def load_mpe_service_metadata(f) -> MPEServiceMetadata: metadata = MPEServiceMetadata() metadata.load(f) return metadata -def mpe_service_metadata_from_json(j): +def mpe_service_metadata_from_json(j) -> MPEServiceMetadata: metadata = MPEServiceMetadata() metadata.set_from_json(j) return metadata diff --git a/snet/sdk/storage_provider/storage_provider.py b/snet/sdk/storage_provider/storage_provider.py index 852fbc3..cf698af 100644 --- a/snet/sdk/storage_provider/storage_provider.py +++ b/snet/sdk/storage_provider/storage_provider.py @@ -4,7 +4,7 @@ from snet.sdk.utils.ipfs_utils import get_ipfs_client, get_from_ipfs_and_checkhash from snet.sdk.utils.utils import bytesuri_to_hash, safe_extract_proto -from snet.sdk.storage_provider.service_metadata import mpe_service_metadata_from_json +from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata, mpe_service_metadata_from_json class StorageProvider(object): def __init__(self, config, registry_contract): @@ -29,20 +29,34 @@ def fetch_org_metadata(self,org_id): return org_metadata - def fetch_service_metadata(self,org_id,service_id): + def fetch_service_metadata(self, org_id: str, + service_id: str) -> MPEServiceMetadata: org = web3.Web3.to_bytes(text=org_id).ljust(32, b"\0") service = web3.Web3.to_bytes(text=service_id).ljust(32, b"\0") - found, _, service_metadata_uri = self._registry_contract.functions.getServiceRegistrationById(org, service).call() + found, _, service_metadata_uri = ( + self._registry_contract.functions.getServiceRegistrationById( + org, + service + ).call() + ) if found is not True: - raise Exception('No service "{}" found in organization "{}"'.format(service_id, org_id)) + raise Exception(f"No service '{service_id}' " + f"found in organization '{org_id}'") - service_provider_type, service_metadata_hash = bytesuri_to_hash(service_metadata_uri) + service_provider_type, service_metadata_hash = bytesuri_to_hash( + s=service_metadata_uri + ) if service_provider_type == "ipfs": - service_metadata_json = get_from_ipfs_and_checkhash(self._ipfs_client, service_metadata_hash) + service_metadata_json = get_from_ipfs_and_checkhash( + self._ipfs_client, + service_metadata_hash + ) else: - service_metadata_json, _ = self.lighthouse_client.download(service_metadata_hash) + service_metadata_json, _ = self.lighthouse_client.download( + cid=service_metadata_hash + ) service_metadata = mpe_service_metadata_from_json(service_metadata_json) return service_metadata diff --git a/snet/sdk/training/exceptions.py b/snet/sdk/training/exceptions.py new file mode 100644 index 0000000..421aee0 --- /dev/null +++ b/snet/sdk/training/exceptions.py @@ -0,0 +1,31 @@ +from grpc import RpcError + + +class WrongDatasetException(Exception): + def __init__(self, errors: list[str]): + self.errors = errors + exception_msg = "Dataset check failed:\n" + for check in errors: + exception_msg += f"\t{check}\n" + super().__init__(exception_msg) + + +class WrongMethodException(Exception): + def __init__(self, method_name: str): + super().__init__(f"Method with name {method_name} not found!") + + +class NoTrainingException(Exception): + def __init__(self, org_id: str, service_id: str): + super().__init__(f"Training is not implemented for the service with org_id={org_id} and service_id={service_id}!") + + +class GRPCException(RpcError): + def __init__(self, error: RpcError): + super().__init__(f"An error occurred during the grpc call: {error}.") + + +class NoSuchModelException(Exception): + def __init__(self, model_id: str): + super().__init__(f"Model with id {model_id} not found!") + diff --git a/snet/sdk/training/responses.py b/snet/sdk/training/responses.py new file mode 100644 index 0000000..b2b01bc --- /dev/null +++ b/snet/sdk/training/responses.py @@ -0,0 +1,195 @@ +from enum import Enum +from typing import Any + + +class ModelMethodMessage(Enum): + CreateModel = "create_model" + ValidateModelPrice = "validate_model_price" + TrainModelPrice = "train_model_price" + DeleteModel = "delete_model" + GetTrainingMetadata = "get_training_metadata" + GetAllModels = "get_all_models" + GetModel = "get_model" + UpdateModel = "update_model" + GetMethodMetadata = "get_method_metadata" + UploadAndValidate = "upload_and_validate" + ValidateModel = "validate_model" + TrainModel = "train_model" + + +class ModelStatus(Enum): + CREATED = 0 + VALIDATING = 1 + VALIDATED = 2 + TRAINING = 3 + READY_TO_USE = 4 + ERRORED = 5 + DELETED = 6 + + +def to_string(obj: Any): + res_str = "" + for key, value in obj.__dict__.items(): + key = key.split("__")[1].replace("_", " ") + res_str += f"{key}: {value}\n" + return res_str + + +class Model: + def __init__(self, model_response): + self.__model_id = model_response.model_id + self.__status = ModelStatus(model_response.status) + self.__created_date = model_response.created_date + self.__updated_date = model_response.updated_date + self.__name = model_response.name + self.__description = model_response.description + self.__grpc_method_name = model_response.grpc_method_name + self.__grpc_service_name = model_response.grpc_service_name + self.__address_list = model_response.address_list + self.__is_public = model_response.is_public + self.__training_data_link = model_response.training_data_link + self.__created_by_address = model_response.created_by_address + self.__updated_by_address = model_response.updated_by_address + + def __str__(self): + return to_string(self) + + @property + def model_id(self): + return self.__model_id + + @property + def status(self): + return self.__status + + @property + def created_date(self): + return self.__created_date + + @property + def updated_date(self): + return self.__updated_date + + @property + def name(self): + return self.__name + + @property + def description(self): + return self.__description + + @property + def grpc_method_name(self): + return self.__grpc_method_name + + @property + def grpc_service_name(self): + return self.__grpc_service_name + + @property + def address_list(self): + return self.__address_list + + @property + def is_public(self): + return self.__is_public + + @property + def training_data_link(self): + return self.__training_data_link + + @property + def created_by_address(self): + return self.__created_by_address + + @property + def updated_by_address(self): + return self.__updated_by_address + + +class TrainingMetadata: + def __init__(self, + training_enabled: bool, + training_in_proto: bool, + training_methods: Any): + + self.__training_enabled = training_enabled + self.__training_in_proto = training_in_proto + self.__training_methods = {} + + services_methods = dict(training_methods) + for k, v in services_methods.items(): + self.__training_methods[k] = [value.string_value for value in v.values] + + def __str__(self): + return to_string(self) + + @property + def training_enabled(self): + return self.__training_enabled + + @property + def training_in_proto(self): + return self.__training_in_proto + + @property + def training_methods(self): + return self.__training_methods + + +class MethodMetadata: + def __init__(self, + default_model_id: str, + max_models_per_user: int, + dataset_max_size_mb: int, + dataset_max_count_files: int, + dataset_max_size_single_file_mb: int, + dataset_files_type: str, + dataset_type: str, + dataset_description: str): + + self.__default_model_id = default_model_id + self.__max_models_per_user = max_models_per_user + self.__dataset_max_size_mb = dataset_max_size_mb + self.__dataset_max_count_files = dataset_max_count_files + self.__dataset_max_size_single_file_mb = dataset_max_size_single_file_mb + self.__dataset_files_type = dataset_files_type + self.__dataset_type = dataset_type + self.__dataset_description = dataset_description + + def __str__(self): + return to_string(self) + + @property + def default_model_id(self): + return self.__default_model_id + + @property + def max_models_per_user(self): + return self.__max_models_per_user + + @property + def dataset_max_size_mb(self): + return self.__dataset_max_size_mb + + @property + def dataset_max_count_files(self): + return self.__dataset_max_count_files + + @property + def dataset_max_size_single_file_mb(self): + return self.__dataset_max_size_single_file_mb + + @property + def dataset_files_type(self): + return self.__dataset_files_type + + @property + def dataset_type(self): + return self.__dataset_type + + @property + def dataset_description(self): + return self.__dataset_description + + diff --git a/snet/sdk/training/training.py b/snet/sdk/training/training.py index c5ede80..100dc1a 100644 --- a/snet/sdk/training/training.py +++ b/snet/sdk/training/training.py @@ -1,157 +1,399 @@ -import enum -import importlib -from urllib.parse import urlparse - +from pathlib import PurePath, Path +import os +from zipfile import ZipFile +from typing import Any import grpc -import web3 +import importlib -from snet.sdk.resources.root_certificate import certificate +from snet.sdk import generic_client_interceptor +from snet.sdk.payment_strategies.training_payment_strategy import TrainingPaymentStrategy +from snet.sdk.utils.call_utils import create_intercept_call_func from snet.sdk.utils.utils import add_to_path, RESOURCES_PATH +from snet.sdk.training.exceptions import ( + WrongDatasetException, + WrongMethodException, + GRPCException, + NoSuchModelException +) +from snet.sdk.training.responses import ( + ModelStatus, + Model, + TrainingMetadata, + MethodMetadata, + ModelMethodMessage +) -# for local debug -# from snet.snet_cli.resources.proto import training_pb2_grpc -# from snet.snet_cli.resources.proto import training_pb2 +class Training: + def __init__(self, service_client, training_added=False): + with add_to_path(str(RESOURCES_PATH.joinpath("proto"))): + self.training_daemon = importlib.import_module("training_daemon_pb2") + self.training_daemon_grpc = importlib.import_module("training_daemon_pb2_grpc") + self.training = importlib.import_module("training_pb2") + + self.service_client = service_client + self.is_enabled = training_added and self._check_training() + self.payment_strategy = TrainingPaymentStrategy() + def get_model_id_object(self, model_id: str) -> Any: + return self.training.ModelID(model_id=model_id) -# from daemon code -class ModelMethodMessage(enum.Enum): - CreateModel = "__CreateModel" - GetModelStatus = "__GetModelStatus" - UpdateModelAccess = "__UpdateModelAccess" - GetAllModels = "__UpdateModelAccess" - DeleteModel = "__GetModelStatus" + """FREE METHODS TO CALL""" + def create_model(self, method_name: str, + model_name: str, + model_description: str="", + is_public_accessible: bool=False, + addresses_with_access: list[str]=None) -> Model: + if addresses_with_access is None: + addresses_with_access = [] -class TrainingModel: + service_name, method_name = self._check_method_name(method_name) + new_model = self.training.NewModel(name=model_name, + description=model_description, + grpc_method_name=method_name, + grpc_service_name=service_name, + is_public=is_public_accessible, + address_list=addresses_with_access, + organization_id="", + service_id="", + group_id="") + auth_details = self._get_auth_details(ModelMethodMessage.CreateModel) + create_model_request = self.training_daemon.NewModelRequest(authorization=auth_details, + model=new_model) + response = self._call_method("create_model", + request_data=create_model_request) + model = Model(response) - def __init__(self): - with add_to_path(str(RESOURCES_PATH.joinpath("proto"))): - self.training_pb2 = importlib.import_module("training_pb2") + return model - with add_to_path(str(RESOURCES_PATH.joinpath("proto"))): - self.training_pb2_grpc = importlib.import_module("training_pb2_grpc") + def validate_model_price(self, model_id: str) -> int: + auth_details = self._get_auth_details(ModelMethodMessage.ValidateModelPrice) + validate_model_price_request = self.training_daemon.AuthValidateRequest( + authorization=auth_details, + model_id=model_id, + training_data_link="" + ) - def _invoke_model(self, service_client, msg: ModelMethodMessage): - org_id, service_id, group_id, daemon_endpoint = service_client.get_service_details() + try: + response = self._call_method("validate_model_price", + request_data=validate_model_price_request) + except GRPCException as e: + if "unable to access model" in str(e): + raise NoSuchModelException(model_id) + else: + raise e - endpoint_object = urlparse(daemon_endpoint) - if endpoint_object.port is not None: - channel_endpoint = endpoint_object.hostname + ":" + str(endpoint_object.port) - else: - channel_endpoint = endpoint_object.hostname - - if endpoint_object.scheme == "http": - print("creating http channel: ", channel_endpoint) - channel = grpc.insecure_channel(channel_endpoint) - elif endpoint_object.scheme == "https": - channel = grpc.secure_channel(channel_endpoint, - grpc.ssl_channel_credentials(root_certificates=certificate)) + return response.price + + def train_model_price(self, model_id: str) -> int: + auth_details = self._get_auth_details(ModelMethodMessage.TrainModelPrice) + common_request = self.training_daemon.CommonRequest(authorization=auth_details, + model_id=model_id) + try: + response = self._call_method("train_model_price", + request_data=common_request) + except GRPCException as e: + if "unable to access model" in str(e): + raise NoSuchModelException(model_id) + else: + raise e + + return response.price + + def delete_model(self, model_id: str) -> ModelStatus: + auth_details = self._get_auth_details(ModelMethodMessage.DeleteModel) + common_request = self.training_daemon.CommonRequest(authorization=auth_details, + model_id=model_id) + try: + response = self._call_method("delete_model", + request_data=common_request) + except GRPCException as e: + if "unable to access model" in str(e): + raise NoSuchModelException(model_id) + else: + raise e + + return ModelStatus(response.status) + + def get_training_metadata(self) -> TrainingMetadata: + empty_request = self.training_daemon.google_dot_protobuf_dot_empty__pb2.Empty() + response = self._call_method("get_training_metadata", + request_data=empty_request) + + training_metadata = TrainingMetadata(training_enabled=response.trainingEnabled, + training_in_proto=response.trainingInProto, + training_methods=response.trainingMethods) + + return training_metadata + + def get_all_models(self, statuses: list[ModelStatus]=None, + is_public: bool=None, + grpc_method_name: str="", + grpc_service_name: str="", + model_name: str="", + created_by_address: str="") -> list[Model]: + if statuses is not None: + statuses = [getattr(self.training.Status, status.value) for status in statuses] + auth_details = self._get_auth_details(ModelMethodMessage.GetAllModels) + all_models_request = self.training_daemon.AllModelsRequest( + authorization=auth_details, + statuses=statuses, + is_public=is_public, + grpc_method_name=grpc_method_name, + grpc_service_name=grpc_service_name, + name=model_name, + created_by_address=created_by_address, + page_size=0, + page=0 + ) + response = self._call_method("get_all_models", + request_data=all_models_request) + models = [] + for model in response.list_of_models: + models.append(Model(model)) + + return models + + def get_model(self, model_id: str) -> Model: + auth_details = self._get_auth_details(ModelMethodMessage.GetModel) + common_request = self.training_daemon.CommonRequest(authorization=auth_details, + model_id=model_id) + try: + response = self._call_method("get_model", + request_data=common_request) + except GRPCException as e: + if "unable to access model" in str(e): + raise NoSuchModelException(model_id) + else: + raise e + model = Model(response) + + return model + + def get_method_metadata(self, method_name: str, model_id: str= "") -> MethodMetadata: + + if model_id: + requirements_request = self.training_daemon.MethodMetadataRequest( + grpc_method_name="", + grpc_service_name="", + model_id=model_id + ) else: - raise ValueError('Unsupported scheme in service metadata ("{}")'.format(endpoint_object.scheme)) - - current_block_number = service_client.get_current_block_number() - signature = service_client.generate_training_signature(msg.value, web3.Web3.to_checksum_address( - service_client.account.address), current_block_number) - auth_req = self.training_pb2.AuthorizationDetails(signature=bytes(signature), - current_block=current_block_number, - signer_address=service_client.account.address, - message=msg.value) - return auth_req, channel - - # params from AI-service: status, model_id - # params pass to daemon: grpc_service_name, grpc_method_name, address_list, - # description, model_name, training_data_link, is_public_accessible - def create_model(self, service_client, grpc_method_name: str, - model_name: str, description: str = '', - training_data_link: str = '', grpc_service_name='service', - is_publicly_accessible=False, address_list: list[str] = None): - if address_list is None: - address_list = [] + service_name, method_name = self._check_method_name(method_name) + requirements_request = self.training_daemon.MethodMetadataRequest( + grpc_method_name=method_name, + grpc_service_name=service_name, + model_id="" + ) + response = self._call_method("get_method_metadata", + request_data=requirements_request) + + method_metadata = MethodMetadata(default_model_id = response.default_model_id, + max_models_per_user = response.max_models_per_user, + dataset_max_size_mb = response.dataset_max_size_mb, + dataset_max_count_files = response.dataset_max_count_files, + dataset_max_size_single_file_mb = response.dataset_max_size_single_file_mb, + dataset_files_type = response.dataset_files_type, + dataset_type = response.dataset_type, + dataset_description = response.dataset_description) + return method_metadata + + def update_model(self, model_id: str, + model_name: str=None, + description: str=None, + addresses_with_access: list[str]=None) -> Model: + auth_details = self._get_auth_details(ModelMethodMessage.UpdateModel) + update_model_request = self.training_daemon.UpdateModelRequest( + authorization=auth_details,model_id=model_id, + model_name=model_name, + description=description, + address_list=addresses_with_access + ) + try: - auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.CreateModel) - model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name, - description=description, - training_data_link=training_data_link, - grpc_service_name=grpc_service_name, - model_name=model_name, address_list=address_list, - is_publicly_accessible=is_publicly_accessible) - stub = self.training_pb2_grpc.ModelStub(channel) - response = stub.create_model( - self.training_pb2.CreateModelRequest(authorization=auth_req, model_details=model_details)) - return response - except Exception as e: - print("Exception: ", e) - return e + response = self._call_method("update_model", + request_data=update_model_request) + except GRPCException as e: + if "unable to access model" in str(e): + raise NoSuchModelException(model_id) + else: + raise e + + model = Model(response) + + return model + + """PAID METHODS TO CALL""" + + def upload_and_validate(self, model_id: str, + zip_path: str | Path | PurePath, price: int) -> ModelStatus: + f = open(zip_path, 'rb') + file_size = os.path.getsize(zip_path) + + file_name = os.path.basename(zip_path) + file_size = file_size + batch_size = 1024*1024 + batch_count = file_size // batch_size + if file_size % batch_size != 0: + batch_count += 1 + + self._check_dataset(model_id, zip_path) + + auth_details = self._get_auth_details(ModelMethodMessage.UploadAndValidate) + + def request_iter(file): + batch = file.read(batch_size) + batch_number = 1 + while batch: + upload_input = self.training.UploadInput( + model_id = model_id, + data = batch, + file_name = file_name, + file_size = file_size, + batch_size = batch_size, + batch_number = batch_number, + batch_count = batch_count + ) + yield self.training_daemon.UploadAndValidateRequest( + authorization=auth_details, + upload_input=upload_input + ) + batch = file.read(batch_size) + batch_number += 1 + + self.payment_strategy.set_price(price) + self.payment_strategy.set_model_id(model_id) - # params from AI-service: status - # params to daemon: grpc_service_name, grpc_method_name, model_id - def get_model_status(self, service_client, model_id: str, grpc_method_name: str, grpc_service_name='service'): try: - auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.GetModelStatus) - model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name, - grpc_service_name=grpc_service_name, - model_id=str(model_id)) - stub = self.training_pb2_grpc.ModelStub(channel) - response = stub.get_model_status( - self.training_pb2.ModelDetailsRequest(authorization=auth_req, model_details=model_details)) - return response - except Exception as e: - print("Exception: ", e) - return e - - # params from AI-service: status - # params to daemon: grpc_service_name, grpc_method_name, model_id - def delete_model(self, service_client, model_id: str, - grpc_method_name: str, grpc_service_name='service'): + response = self._call_method("upload_and_validate", + request_data=request_iter(f), + paid=True) + except GRPCException as e: + if "unable to access model" in str(e): + raise NoSuchModelException(model_id) + else: + raise e + finally: + f.close() + + return ModelStatus(response.status) + + def train_model(self, model_id: str, price: int) -> ModelStatus: + auth_details = self._get_auth_details(ModelMethodMessage.TrainModel) + common_request = self.training_daemon.CommonRequest(authorization=auth_details, + model_id=model_id) + self.payment_strategy.set_price(price) + self.payment_strategy.set_model_id(model_id) + try: - auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.DeleteModel) - model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name, - grpc_service_name=grpc_service_name, - model_id=str(model_id)) - stub = self.training_pb2_grpc.ModelStub(channel) - response = stub.delete_model( - self.training_pb2.UpdateModelRequest(authorization=auth_req, update_model_details=model_details)) - return response - except Exception as e: - print("Exception: ", e) - return e - - # params from AI-service: None - # params to daemon: grpc_service_name, grpc_method_name, model_id, address_list, is_public, model_name, desc - # all params required - def update_model_access(self, service_client, model_id: str, grpc_method_name: str, - model_name: str, is_public: bool, - description: str, grpc_service_name: str = 'service', - address_list: list[str] = None): + response = self._call_method("train_model", + request_data=common_request, + paid=True) + except GRPCException as e: + if "unable to access model" in str(e): + raise NoSuchModelException(model_id) + else: + raise e + + return ModelStatus(response.status) + + """PRIVATE METHODS""" + + def _call_method(self, method_name: str, + request_data, + paid=False) -> Any: try: - auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.UpdateModelAccess) - model_details = self.training_pb2.ModelDetails(grpc_method_name=grpc_method_name, - description=description, - grpc_service_name=grpc_service_name, - address_list=address_list, - is_publicly_accessible=is_public, - model_name=model_name, - model_id=str(model_id)) - stub = self.training_pb2_grpc.ModelStub(channel) - response = stub.update_model_access( - self.training_pb2.UpdateModelRequest(authorization=auth_req, update_model_details=model_details)) + stub = self._get_training_stub(paid=paid) + response = getattr(stub, method_name)(request_data) return response - except Exception as e: - print("Exception: ", e) - return e + except grpc.RpcError as e: + raise GRPCException(e) + + def _get_training_stub(self, paid=False) -> Any: + grpc_channel = self.service_client.get_grpc_base_channel() + if paid: + grpc_channel = self._get_grpc_channel(grpc_channel) + return self.training_daemon_grpc.DaemonStub(grpc_channel) + + def _get_auth_details(self, method_msg: ModelMethodMessage) -> Any: + current_block_number = self.service_client.get_current_block_number() + address = self.service_client.account.address + signature = self.service_client.generate_training_signature(method_msg.value, + address, + current_block_number) + auth_details = self.training_daemon.AuthorizationDetails( + signature=bytes(signature), + current_block=current_block_number, + signer_address=address, + message=method_msg.value + ) + return auth_details - # params from AI-service: None - # params to daemon: grpc_service_name, grpc_method_name - def get_all_models(self, service_client, grpc_method_name: str, grpc_service_name='service'): + def _check_method_name(self, method_name: str) -> tuple[str, str]: + services_methods, _ = self.service_client.get_services_and_messages_info() + for service, methods in services_methods.items(): + for method in methods: + if method[0] == method_name: + return service, method[0] + raise WrongMethodException(method_name) + + def _check_training(self) -> bool: try: - auth_req, channel = self._invoke_model(service_client, ModelMethodMessage.GetAllModels) - stub = self.training_pb2_grpc.ModelStub(channel) - response = stub.get_all_models( - self.training_pb2.AccessibleModelsRequest(authorization=auth_req, - grpc_service_name=grpc_service_name, - grpc_method_name=grpc_method_name)) - return response - except Exception as e: - print("Exception: ", e) - return e + service_methods = self.get_training_metadata().training_methods + except GRPCException as e: + return False + if len(service_methods.keys()) == 0: + return False + n_methods = 0 + for service, methods in service_methods.items(): + n_methods += len(methods) + if n_methods == 0: + return False + else: + return True + + def _check_dataset(self, model_id: str, zip_path: str | Path | PurePath) -> None: + method_metadata = self.get_method_metadata("", model_id) + max_size_mb = method_metadata.dataset_max_size_mb + max_count_files = method_metadata.dataset_max_count_files + max_size_mb_single = method_metadata.dataset_max_size_single_file_mb + file_types = method_metadata.dataset_files_type + file_types = [i for i in file_types.replace(' ', '').split(',')] + if file_types[0] == '': + file_types = [] + + failed_checks = [] + zip_file = ZipFile(zip_path) + + if os.path.getsize(zip_path) > max_size_mb * 1024 * 1024 > 0: + failed_checks.append(f"Too big dataset size: " + f"{os.path.getsize(zip_path)//1024/1024} MB > {max_size_mb} MB") + + files_list = zip_file.infolist() + if len(files_list) > max_count_files > 0: + failed_checks.append(f"Too many files: {len(files_list)} > {max_count_files}") + + for file_info in files_list: + _, extension = os.path.splitext(file_info.filename) + extension = extension[1:] if extension.startswith('.') else extension + if len(file_types) > 0 and extension not in file_types: + failed_checks.append(f"Wrong file type: `{extension}` in file: " + f"`{file_info.filename}`. Allowed file types: " + f"{', '.join(file_types)}") + if file_info.file_size > max_size_mb_single * 1024 * 1024 > 0: + failed_checks.append(f"Too big file `{file_info.filename}` size: " + f"{file_info.file_size // 1024 / 1024} MB > " + f"{max_size_mb_single} MB") + + if len(failed_checks) > 0: + raise WrongDatasetException(failed_checks) + + def _get_grpc_channel(self, base_channel: grpc.Channel) -> grpc.Channel: + intercept_call_func = create_intercept_call_func(self.payment_strategy.get_payment_metadata, + self.service_client) + grpc_channel = grpc.intercept_channel( + base_channel, + generic_client_interceptor.create(intercept_call_func) + ) + return grpc_channel diff --git a/snet/sdk/utils/call_utils.py b/snet/sdk/utils/call_utils.py new file mode 100644 index 0000000..a1a6669 --- /dev/null +++ b/snet/sdk/utils/call_utils.py @@ -0,0 +1,24 @@ +import collections +import grpc + + +class _ClientCallDetails( + collections.namedtuple( + '_ClientCallDetails', + ('method', 'timeout', 'metadata', 'credentials')), + grpc.ClientCallDetails): + pass + + +def create_intercept_call_func(get_metadata_func: callable, service_client) -> callable: + def intercept_call(client_call_details, request_iterator, request_streaming, response_streaming): + metadata = [] + if client_call_details.metadata is not None: + metadata = list(client_call_details.metadata) + metadata.extend(get_metadata_func(service_client)) + client_call_details = _ClientCallDetails( + client_call_details.method, client_call_details.timeout, metadata, + client_call_details.credentials) + return client_call_details, request_iterator, None + + return intercept_call diff --git a/snet/sdk/utils/utils.py b/snet/sdk/utils/utils.py index 21c4369..b036be8 100644 --- a/snet/sdk/utils/utils.py +++ b/snet/sdk/utils/utils.py @@ -9,6 +9,7 @@ import io import web3 +from eth_typing import BlockNumber from grpc_tools.protoc import main as protoc from snet import sdk @@ -43,7 +44,13 @@ def bytes32_to_str(b): return b.rstrip(b"\0").decode("utf-8") -def compile_proto(entry_path, codegen_dir, proto_file=None, target_language="python"): +def compile_proto( + entry_path: Path, + codegen_dir: Path, + proto_file: str | None = None, + target_language: str = "python", + add_training: bool = False +) -> bool: try: if not os.path.exists(codegen_dir): os.makedirs(codegen_dir) @@ -54,31 +61,26 @@ def compile_proto(entry_path, codegen_dir, proto_file=None, target_language="pyt "-I{}".format(proto_include) ] + if add_training: + training_include = RESOURCES_PATH.joinpath("proto", "training") + compiler_args.append("-I{}".format(training_include)) + if target_language == "python": compiler_args.insert(0, "protoc") compiler_args.append("--python_out={}".format(codegen_dir)) compiler_args.append("--grpc_python_out={}".format(codegen_dir)) compiler = protoc - elif target_language == "nodejs": - protoc_node_compiler_path = Path( - RESOURCES_PATH.joinpath("node_modules").joinpath("grpc-tools").joinpath("bin").joinpath( - "protoc.js")).absolute() - grpc_node_plugin_path = Path( - RESOURCES_PATH.joinpath("node_modules").joinpath("grpc-tools").joinpath("bin").joinpath( - "grpc_node_plugin")).resolve() - if not os.path.isfile(protoc_node_compiler_path) or not os.path.isfile(grpc_node_plugin_path): - print("Missing required node.js protoc compiler. Retrieving from npm...") - subprocess.run(["npm", "install"], cwd=RESOURCES_PATH) - compiler_args.append("--js_out=import_style=commonjs,binary:{}".format(codegen_dir)) - compiler_args.append("--grpc_out={}".format(codegen_dir)) - compiler_args.append("--plugin=protoc-gen-grpc={}".format(grpc_node_plugin_path)) - compiler = lambda args: subprocess.run([str(protoc_node_compiler_path)] + args) + else: + raise Exception("We only support python target language for proto compiling") if proto_file: compiler_args.append(str(proto_file)) else: compiler_args.extend([str(p) for p in entry_path.glob("**/*.proto")]) + if add_training: + compiler_args.append(str(training_include.joinpath("training.proto"))) + if not compiler(compiler_args): return True else: @@ -128,6 +130,10 @@ def get_address_from_private(private_key): return web3.Account.from_key(private_key).address +def get_current_block_number() -> BlockNumber: + return web3.Web3().eth.block_number + + class add_to_path: def __init__(self, path): self.path = path @@ -142,10 +148,12 @@ def __exit__(self, exc_type, exc_value, traceback): pass -def find_file_by_keyword(directory, keyword): +def find_file_by_keyword(directory, keyword, exclude=None): + if exclude is None: + exclude = [] for root, dirs, files in os.walk(directory): for file in files: - if keyword in file: + if keyword in file and all(e not in file for e in exclude): return file @@ -177,6 +185,6 @@ def safe_extract_proto(spec_tar, protodir): fullname = os.path.join(protodir, m.name) if os.path.exists(fullname): os.remove(fullname) - print("%s removed." % fullname) + print(f"{fullname} removed.") # now it is safe to call extractall - f.extractall(protodir) + f.extractall(path=protodir) diff --git a/testcases/functional_tests/test_sdk_client.py b/testcases/functional_tests/test_sdk_client.py index 4028ab9..f83f59b 100644 --- a/testcases/functional_tests/test_sdk_client.py +++ b/testcases/functional_tests/test_sdk_client.py @@ -1,5 +1,4 @@ import unittest -import shutil import os from snet import sdk @@ -7,20 +6,13 @@ class TestSDKClient(unittest.TestCase): def setUp(self): - self.service_client, self.path_to_pb_files = get_test_service_data() + self.service_client = get_test_service_data() channel = self.service_client.deposit_and_open_channel(123456, 33333) def test_call_to_service(self): result = self.service_client.call_rpc("mul", "Numbers", a=20, b=3) self.assertEqual(60.0, result.value) - def tearDown(self): - try: - shutil.rmtree(self.path_to_pb_files) - print(f"Directory '{self.path_to_pb_files}' has been removed successfully after testing.") - except OSError as e: - print(f"Error: {self.path_to_pb_files} : {e.strerror}") - def get_test_service_data(): config = sdk.config.Config(private_key=os.environ['SNET_TEST_WALLET_PRIVATE_KEY'], @@ -32,9 +24,7 @@ def get_test_service_data(): service_client = snet_sdk.create_service_client(org_id="26072b8b6a0e448180f8c0e702ab6d2f", service_id="Exampleservice", group_name="default_group") - path_to_pb_files = snet_sdk.get_path_to_pb_files(org_id="26072b8b6a0e448180f8c0e702ab6d2f", - service_id="Exampleservice") - return service_client, path_to_pb_files + return service_client if __name__ == '__main__': diff --git a/testcases/utils/run_all_functional.sh b/testcases/utils/run_all_functional.sh index 224fe05..55c423a 100755 --- a/testcases/utils/run_all_functional.sh +++ b/testcases/utils/run_all_functional.sh @@ -1,4 +1,4 @@ -./testcases/utils/reset_environment.sh +#./testcases/utils/reset_environment.sh cd testcases/functional_tests python3 test_sdk_client.py diff --git a/tests/unit_tests/test_account.py b/tests/unit_tests/test_account.py new file mode 100644 index 0000000..4889609 --- /dev/null +++ b/tests/unit_tests/test_account.py @@ -0,0 +1,153 @@ +import os +import unittest +from unittest.mock import MagicMock, patch + +from dotenv import load_dotenv +from web3 import Web3 + +from snet.sdk.account import Account, TransactionError +from snet.sdk.config import Config +from snet.sdk.mpe.mpe_contract import MPEContract + +load_dotenv() + + +class TestAccount(unittest.TestCase): + @patch("snet.sdk.account.get_contract_object") + def setUp(self, mock_get_contract_object): + # Mock main fields + self.mock_web3 = MagicMock(spec=Web3) + self.mock_config = MagicMock(spec=Config) + self.mock_mpe_contract = MagicMock(spec=MPEContract) + + # Mock additional fields + self.mock_web3.eth = MagicMock() + self.mock_web3.net = MagicMock() + self.mock_mpe_contract.contract = MagicMock() + + # Config mock return values + self.mock_config.get.side_effect = lambda key, default=None: { + "private_key": os.getenv("PRIVATE_KEY"), + "signer_private_key": None, + "token_contract_address": None, + }.get(key, default) + + # Mock token contract + self.mock_token_contract = MagicMock() + self.mock_get_contract_object = mock_get_contract_object + self.mock_get_contract_object.return_value = self.mock_token_contract + + self.account = Account(self.mock_web3, self.mock_config, + self.mock_mpe_contract) + + def test_get_nonce(self): + for i in [4, 5]: + self.mock_web3.eth.get_transaction_count.return_value = i + self.account.nonce = 4 + nonce = self.account._get_nonce() + self.assertEqual(nonce, 5) + self.assertEqual(self.account.nonce, 5) + + def test_get_gas_price(self): + # Test different gas price levels + gas_price = 10000000000 + self.mock_web3.eth.gas_price = gas_price + gas_price = self.account._get_gas_price() + self.assertEqual(gas_price, int(gas_price + (gas_price * 1 / 3))) + + gas_price = 16000000000 + self.mock_web3.eth.gas_price = gas_price + gas_price = self.account._get_gas_price() + self.assertEqual(gas_price, int(gas_price + (gas_price * 1 / 5))) + + gas_price = 51200000000 + self.mock_web3.eth.gas_price = 51200000000 + gas_price = self.account._get_gas_price() + self.assertEqual(gas_price, int(gas_price + 7000000000)) + + gas_price = 150000000001 + self.mock_web3.eth.gas_price = 150000000001 + gas_price = self.account._get_gas_price() + self.assertEqual(gas_price, int(gas_price + (gas_price * 1 / 10))) + + # @patch("snet.sdk.web3.Web3.to_hex", side_effect=lambda x: "mock_txn_hash") + # def test_send_signed_transaction(self, mock_to_hex): + # # Mock contract function + # mock_contract_fn = MagicMock() + # mock_contract_fn.return_value.build_transaction.return_value = {"mock": "txn"} + + # # Test transaction sending + # txn_hash = self.account._send_signed_transaction(mock_contract_fn) + # self.assertEqual(txn_hash, "mock_txn_hash") + # self.mock_web3.eth.account.sign_transaction.assert_called_once() + # self.mock_web3.eth.send_raw_transaction.assert_called_once() + + def test_parse_receipt_success(self): + # Mock receipt and event + mock_receipt = MagicMock() + mock_receipt.status = 1 + mock_event = MagicMock() + mock_event.return_value.processReceipt.return_value = [ + {"args": {"key": "value"}} + ] + + result = self.account._parse_receipt(mock_receipt, mock_event) + self.assertEqual(result, '{"key": "value"}') + + def test_parse_receipt_failure(self): + # Mock a failing receipt + mock_receipt = MagicMock() + mock_receipt.status = 0 + + with self.assertRaises(TransactionError) as context: + self.account._parse_receipt(mock_receipt, None) + + self.assertEqual(str(context.exception), "Transaction failed") + self.assertEqual(context.exception.receipt, mock_receipt) + + def test_escrow_balance(self): + self.mock_mpe_contract.balance.return_value = 120000000 + balance = self.account.escrow_balance() + self.assertIsInstance(balance, int) + self.assertEqual(balance, 120000000) + self.mock_mpe_contract.balance.assert_called_once_with( + self.account.address + ) + + def test_deposit_to_escrow_account(self): + self.account.allowance = MagicMock(return_value=0) + self.account.approve_transfer = MagicMock() + self.mock_mpe_contract.deposit.return_value = "0x51ec7c89064d95416be4" + + result = self.account.deposit_to_escrow_account(100) + self.account.approve_transfer.assert_called_once_with(100) + self.mock_mpe_contract.deposit.assert_called_once_with(self.account, + 100) + self.assertEqual(result, "0x51ec7c89064d95416be4") + + def test_approve_transfer(self): + self.mock_web3.eth.gas_price = 10000000000 + self.mock_web3.eth.get_transaction_count.return_value = 1 + result = self.account.approve_transfer(500) + self.assertIsNotNone(result) + self.mock_token_contract.functions.approve.assert_called_once_with( + self.mock_mpe_contract.contract.address, 500 + ) + # def test_approve_transfer(self): + # self.account.send_transaction = MagicMock() + # self.account.send_transaction.return_value = "TxReceipt" + # result = self.account.approve_transfer(500) + # self.assertEqual(result, "TxReceipt") + + def test_allowance(self): + self.mock_token_contract.functions.allowance.return_value.call \ + .return_value = 100 + allowance = self.account.allowance() + self.assertEqual(allowance, 100) + self.mock_token_contract.functions.allowance.assert_called_once_with( + self.account.address, self.mock_mpe_contract.contract.address + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py new file mode 100644 index 0000000..95f4e9e --- /dev/null +++ b/tests/unit_tests/test_config.py @@ -0,0 +1,92 @@ +import unittest + +from snet.sdk.config import Config + + +class TestConfig(unittest.TestCase): + def setUp(self): + self.private_key = "test_private_key" + self.eth_rpc_endpoint = "http://localhost:8545" + self.wallet_index = 1 + self.ipfs_endpoint = "http://custom-ipfs-endpoint.io" + self.mpe_contract_address = "0xMPEAddress" + self.token_contract_address = "0xTokenAddress" + self.registry_contract_address = "0xRegistryAddress" + self.signer_private_key = "signer_key" + + def test_initialization(self): + config = Config( + private_key=self.private_key, + eth_rpc_endpoint=self.eth_rpc_endpoint, + wallet_index=self.wallet_index, + ipfs_endpoint=self.ipfs_endpoint, + concurrency=False, + force_update=True, + mpe_contract_address=self.mpe_contract_address, + token_contract_address=self.token_contract_address, + registry_contract_address=self.registry_contract_address, + signer_private_key=self.signer_private_key + ) + + self.assertEqual(config["private_key"], + self.private_key) + self.assertEqual(config["eth_rpc_endpoint"], + self.eth_rpc_endpoint) + self.assertEqual(config["wallet_index"], + self.wallet_index) + self.assertEqual(config["ipfs_endpoint"], + self.ipfs_endpoint) + self.assertEqual(config["mpe_contract_address"], + self.mpe_contract_address) + self.assertEqual(config["token_contract_address"], + self.token_contract_address) + self.assertEqual(config["registry_contract_address"], + self.registry_contract_address) + self.assertEqual(config["signer_private_key"], + self.signer_private_key) + self.assertEqual(config["lighthouse_token"], " ") + self.assertFalse(config["concurrency"]) + self.assertTrue(config["force_update"]) + + def test_default_values(self): + config = Config( + private_key=self.private_key, + eth_rpc_endpoint=self.eth_rpc_endpoint + ) + + self.assertEqual(config["wallet_index"], 0) + self.assertEqual(config["ipfs_endpoint"], + "/dns/ipfs.singularitynet.io/tcp/80/") + self.assertTrue(config["concurrency"]) + self.assertFalse(config["force_update"]) + self.assertIsNone(config["mpe_contract_address"]) + self.assertIsNone(config["token_contract_address"]) + self.assertIsNone(config["registry_contract_address"]) + self.assertIsNone(config["signer_private_key"]) + + def test_get_method(self): + config = Config(private_key=self.private_key, + eth_rpc_endpoint=self.eth_rpc_endpoint) + + self.assertEqual(config.get("private_key"), self.private_key) + self.assertEqual(config.get("non_existent_key", + "default_value"), "default_value") + self.assertIsNone(config.get("non_existent_key")) + + def test_get_ipfs_endpoint(self): + config = Config(private_key=self.private_key, + eth_rpc_endpoint=self.eth_rpc_endpoint) + self.assertEqual(config.get_ipfs_endpoint(), + "/dns/ipfs.singularitynet.io/tcp/80/") + + config_with_custom_ipfs = Config( + private_key=self.private_key, + eth_rpc_endpoint=self.eth_rpc_endpoint, + ipfs_endpoint=self.ipfs_endpoint + ) + self.assertEqual(config_with_custom_ipfs.get_ipfs_endpoint(), + self.ipfs_endpoint) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/test_files.zip b/tests/unit_tests/test_files.zip new file mode 100644 index 0000000..ed9de4b Binary files /dev/null and b/tests/unit_tests/test_files.zip differ diff --git a/tests/unit_tests/test_lib_generator.py b/tests/unit_tests/test_lib_generator.py new file mode 100644 index 0000000..7d96478 --- /dev/null +++ b/tests/unit_tests/test_lib_generator.py @@ -0,0 +1,161 @@ +import os +from pathlib import Path +import unittest +from unittest.mock import MagicMock, Mock, patch + +from snet.sdk.client_lib_generator import ClientLibGenerator +from snet.sdk.storage_provider.storage_provider import StorageProvider + + +class TestClientLibGenerator(unittest.TestCase): + def setUp(self): + self.mock_metadata_provider = Mock(spec=StorageProvider) + self.org_id = "26072b8b6a0e448180f8c0e702ab6d2f" + self.service_id = "Exampleservice" + self.language = "python" + self.protodir = Path.home().joinpath(".snet_test") + self.generator = ClientLibGenerator( + metadata_provider=self.mock_metadata_provider, + org_id=self.org_id, + service_id=self.service_id, + protodir=self.protodir + ) + + @patch("pathlib.Path.mkdir") + def test_generate_directories_by_params_by_absolute_path(self, mock_mkdir): + expected_library_dir = self.protodir.joinpath( + self.org_id, self.service_id, self.language + ) + self.generator.generate_directories_by_params() + self.assertEqual(self.generator.protodir, expected_library_dir) + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + + @patch("pathlib.Path.mkdir") + def test_generate_directories_by_params_by_relative_path(self, mock_mkdir): + self.generator.protodir = Path(".snet_test") + expected_library_dir = Path.cwd().joinpath(self.generator.protodir, + self.org_id, + self.service_id, + self.language) + self.generator.generate_directories_by_params() + self.assertEqual(self.generator.protodir, expected_library_dir) + mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) + + def test_create_service_client_libraries_path(self): + mock_protodir = Mock(spec=Path) + self.generator.protodir = mock_protodir + mock_library_path = Mock(spec=Path) + mock_protodir.joinpath.return_value = mock_library_path + + # Call the method + self.generator.create_service_client_libraries_path() + + # Assert that joinpath and mkdir were called with correct arguments + mock_protodir.joinpath.assert_called_once_with(self.org_id, + self.service_id, + self.language) + mock_library_path.mkdir.assert_called_once_with(parents=True, + exist_ok=True) + + # Assert that the protodir is updated correctly + self.assertEqual(self.generator.protodir, mock_library_path) + + def test_receive_proto_files_success(self): + # Set up mocks + mock_metadata = {"service_api_source": None, + "model_ipfs_hash": os.getenv("MODEL_IPFS_HASH")} + self.mock_metadata_provider.fetch_service_metadata \ + .return_value = mock_metadata + self.generator.protodir = Mock() + self.generator.protodir.exists.return_value = True + + # Call the method + self.generator.receive_proto_files() + + # Check method calls + service_api_source = (mock_metadata.get("service_api_source") or + mock_metadata.get("model_ipfs_hash")) + self.mock_metadata_provider.fetch_service_metadata \ + .assert_called_once_with( + org_id=self.org_id, + service_id=self.service_id + ) + self.mock_metadata_provider.fetch_and_extract_proto \ + .assert_called_once_with( + service_api_source, + self.generator.protodir + ) + + def test_receive_proto_files_failed(self): + self.generator.protodir = Mock() + self.generator.protodir.exists.return_value = False + + with self.assertRaises(Exception) as context: + self.generator.receive_proto_files() + + self.assertEqual(str(context.exception), + "Directory for storing proto files not found") + + # @patch("snet.sdk.utils.utils.compile_proto") + # @patch("snet.sdk.client_lib_generator.ClientLibGenerator.generate_directories_by_params") + # @patch("snet.sdk.client_lib_generator.ClientLibGenerator.receive_proto_files") + # @patch("pathlib.Path.exists") + # @patch("pathlib.Path.mkdir") + # @patch("snet.sdk.client_lib_generator.StorageProvider.fetch_and_extract_proto") + # def test_generate_client_library_success(self, + # mock_fetch_and_extract_proto, + # mock_mkdir, + # mock_exists, + # mock_receive_proto_files, + # mock_generate_directories_by_params, + # mock_compile_proto): + # # Мокаем директорию и путь + # # self.generator.protodir = Mock() + # # self.generator.protodir.exists.return_value = True + # mock_base_dir = MagicMock(spec=Path) + # mock_service_dir = MagicMock(spec=Path) + # example_proto_file = mock_service_dir.joinpath("example_service.proto") + + # # Убедимся, что директория существует + # mock_exists.return_value = True + + # # Подменяем возвращаемое значение при вызове joinpath + # mock_base_dir.joinpath.return_value = mock_service_dir + # mock_service_dir.joinpath.return_value = example_proto_file + # mock_service_dir.exists.return_value = True # Прототипный файл существует + + # # Настроим mock для метода fetch_and_extract_proto + # mock_fetch_and_extract_proto.return_value = None # Симулируем успешный экстракт + + # # Настроим генератор + # self.generator.protodir = mock_base_dir + # self.generator.protodir.joinpath.return_value = mock_service_dir # Возвращаем mock для директории сервиса + + # # Настроим моки для вызова методов + # mock_generate_directories_by_params.return_value = None + # mock_receive_proto_files.return_value = None + + # # Запускаем метод + # self.generator.generate_client_library() + + # # Проверяем, что директории и файл созданы + # mock_generate_directories_by_params.assert_called_once() + # mock_receive_proto_files.assert_called_once() + + # # Убедимся, что compile_proto получил правильный путь + # mock_compile_proto.assert_called_once_with( + # entry_path=str(mock_service_dir), + # codegen_dir=str(mock_service_dir), + # target_language="python" + # ) + + @patch("builtins.print") + def test_generate_client_library_handles_exception(self, mock_print): + with patch("snet.sdk.client_lib_generator.ClientLibGenerator.generate_directories_by_params", # noqa E501 + side_effect=Exception("Test exception")): + self.generator.generate_client_library() + mock_print.assert_called_once_with("Test exception") + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/unit_tests/test_service_client.py b/tests/unit_tests/test_service_client.py new file mode 100644 index 0000000..d04ba37 --- /dev/null +++ b/tests/unit_tests/test_service_client.py @@ -0,0 +1,165 @@ +from pathlib import Path +import unittest +from unittest.mock import MagicMock, Mock, patch, create_autospec + +from web3 import Web3 + +from snet.sdk.account import Account +from snet.sdk.mpe.mpe_contract import MPEContract +from snet.sdk.mpe.payment_channel_provider import PaymentChannelProvider +from snet.sdk.service_client import ServiceClient +from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata + + +class TestServiceClient(unittest.TestCase): + def setUp(self): + self.mock_org_id = "26072b8b6a0e448180f8c0e702ab6d2f" + self.mock_service_id = "Exampleservice" + self.mock_service_metadata = MagicMock(spec=MPEServiceMetadata) + self.mock_group = { + 'free_calls': 0, + 'free_call_signer_address': '0x7DF35C98f41F3Af0df1dc4c7F7D4C19a71Dd059F', + 'daemon_addresses': [ + '0x0709e9b78756b740ab0c64427f43f8305fd6d1a7' + ], + 'pricing': [ + { + 'default': True, + 'price_model': 'fixed_price', + 'price_in_cogs': 1 + } + ], + 'endpoints': [ + 'http://node1.naint.tech:62400' + ], + 'group_id': '/mb90Qs8VktxGQmU0uRu0bSlGgqeDlYrKrs+WbsOvOQ=', + 'group_name': 'default_group', + 'payment': { + 'payment_address': '0x0709e9B78756B740ab0C64427f43f8305fD6D1A7', + 'payment_expiration_threshold': 40320, + 'payment_channel_storage_type': 'etcd', + 'payment_channel_storage_client': { + 'endpoints': [ + 'https://127.0.0.1:2379' + ], + 'request_timeout': '3s', + 'connection_timeout': '5s' + } + } + } + self.mock_service_stub = MagicMock() + self.mock_payment_strategy = MagicMock() + self.mock_options = { + 'free_call_auth_token-bin': '', + 'free-call-token-expiry-block': 0, + 'email': '', + 'concurrency': False, + "endpoint": "http://localhost:5000" + } + self.mock_mpe_contract = MagicMock(spec=MPEContract) + self.mock_mpe_contract.contract = MagicMock() + self.mock_account = MagicMock(spec=Account) + self.mock_account.signer_private_key = MagicMock() + self.mock_sdk_web3 = MagicMock(spec=Web3) + self.mock_sdk_web3.eth = MagicMock() + self.mock_pb2_module = MagicMock() + self.mock_payment_channel_provider = MagicMock( + spec=PaymentChannelProvider + ) + self.mock_path_to_pb_files = MagicMock(spec=Path) + + self.client = ServiceClient( + self.mock_org_id, + self.mock_service_id, + self.mock_service_metadata, + self.mock_group, + self.mock_service_stub, + self.mock_payment_strategy, + self.mock_options, + self.mock_mpe_contract, + self.mock_account, + self.mock_sdk_web3, + self.mock_pb2_module, + self.mock_payment_channel_provider, + self.mock_path_to_pb_files + ) + + def test_call_rpc(self): + # Set up mocks for service and pb2_module + mock_rpc_method = MagicMock(return_value="value: 8") + self.client.service = MagicMock() + self.client.service.mul = mock_rpc_method + self.mock_pb2_module.Numbers = MagicMock() + + # Call the method + result = self.client.call_rpc("mul", "Numbers", a=2, b=4) + + # Assert that Numbers was called with correct arguments + self.mock_pb2_module.Numbers.assert_called_once_with(a=2, b=4) + mock_rpc_method.assert_called_once_with( + self.mock_pb2_module.Numbers.return_value + ) + self.assertEqual(result, mock_rpc_method.return_value) + + @patch("snet.sdk.service_client.grpc.insecure_channel") + def test_get_grpc_channel_http(self, mock_insecure_channel): + channel = self.client._get_grpc_channel() + mock_insecure_channel.assert_called_once_with("localhost:5000") + self.assertEqual(channel, mock_insecure_channel.return_value) + + @patch("snet.sdk.service_client.grpc.ssl_channel_credentials") + @patch("snet.sdk.service_client.grpc.secure_channel") + def test_get_grpc_channel_https(self, + mock_secure_channel, + mock_ssl_channel_credentials): + self.mock_options["endpoint"] = "https://localhost:5000" + channel = self.client._get_grpc_channel() + mock_ssl_channel_credentials.assert_called_once() + mock_secure_channel.assert_called_once_with( + "localhost:5000", + mock_ssl_channel_credentials.return_value + ) + self.assertEqual(channel, mock_secure_channel.return_value) + + def test_filter_existing_channels(self): + mock_existing_channel = MagicMock(channel_id=1) + mock_new_channel_1 = MagicMock(channel_id=2) + mock_new_channel_2 = MagicMock(channel_id=3) + self.client.payment_channels = [mock_existing_channel] + result = self.client._filter_existing_channels_from_new_payment_channels( # noqa E501 + [mock_existing_channel, mock_new_channel_1, mock_new_channel_2] + ) + self.assertEqual(result, [mock_new_channel_1, mock_new_channel_2]) + + def test_get_current_block_number(self): + expected_result = Mock(return_value=12345) + self.client.sdk_web3.eth.block_number = expected_result + result = self.client.get_current_block_number() + self.assertEqual(result, expected_result) + + def test_generate_signature(self): + message = b"test_message" + mock_signature = MagicMock() + self.client.sdk_web3.eth.account.signHash = MagicMock( + return_value=MagicMock(signature=mock_signature) + ) + result = self.client.generate_signature(message) + self.assertEqual(result, bytes(mock_signature)) + + @patch("snet.sdk.service_client.web3.Web3.solidity_keccak") + def test_generate_training_signature(self, mock_solidity_keccak): + text = "test_text" + address = "test_address" + block_number = "test_block_number" + mock_solidity_keccak.return_value = b"test_message" + mock_signature = MagicMock() + self.client.sdk_web3.eth.account.signHash = MagicMock( + return_value=MagicMock(signature=mock_signature) + ) + result = self.client.generate_training_signature(text, address, + block_number) + self.assertEqual(result, mock_signature) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/test_training_v2.py b/tests/unit_tests/test_training_v2.py new file mode 100644 index 0000000..40d60f4 --- /dev/null +++ b/tests/unit_tests/test_training_v2.py @@ -0,0 +1,48 @@ +import os.path +import unittest +from unittest.mock import patch, MagicMock + +from snet.sdk.training.responses import MethodMetadata +from snet.sdk.training.training import Training +from snet.sdk.service_client import ServiceClient +from snet.sdk.training.exceptions import WrongDatasetException + + +class TestTrainingV2(unittest.TestCase): + def setUp(self): + self.mock_service_client = MagicMock(spec=ServiceClient) + self.training = Training(self.mock_service_client, training_added=True) + self.file_path = os.path.join(os.path.dirname(__file__), "test_files.zip") + self.get_metadata_path = "snet.sdk.training.training.Training.get_method_metadata" + + def test_check_dataset_positive(self): + method_metadata = MethodMetadata("test", + 5, + 50, + 10, + 25, + "jpg, png, wav", + "zip", + "test") + + with patch(self.get_metadata_path, return_value=method_metadata): + try: + self.training._check_dataset("test", self.file_path) + except WrongDatasetException as e: + print(e) + assert False + assert True + + def test_check_dataset_negative(self): + method_metadata = MethodMetadata("test", + 5, + 10, + 10, + 5, + "png, mp3, txt", + "zip", + "test") + + with patch(self.get_metadata_path, return_value=method_metadata): + with self.assertRaises(WrongDatasetException): + self.training._check_dataset("test", self.file_path)