diff --git a/nvflare/edge/simulation/config.py b/nvflare/edge/simulation/config.py index 12d5eead0d..88c0604d78 100644 --- a/nvflare/edge/simulation/config.py +++ b/nvflare/edge/simulation/config.py @@ -41,6 +41,7 @@ def load_class(class_path) -> Type: class ConfigParser: def __init__(self, config_file: str): self.job_name = None + self.allow_self_signed = False self.get_job_timeout = None self.processor = None self.endpoint = None @@ -116,6 +117,11 @@ def parse(self, config_file: str): check_positive_number("get_job_timeout", n) self.get_job_timeout = n + n = config.get("allow_self_signed", False) + if not isinstance(n, bool): + raise TypeError(f"allow_self_signed must be a bool, but got {type(n)}") + self.allow_self_signed = n + def _variable_substitution(self, args: Any, variables: dict) -> Any: if isinstance(args, dict): return {k: self._variable_substitution(v, variables) for k, v in args.items()} diff --git a/nvflare/edge/simulation/feg_api.py b/nvflare/edge/simulation/feg_api.py index 2ffd7a31bc..a9dcfe89aa 100644 --- a/nvflare/edge/simulation/feg_api.py +++ b/nvflare/edge/simulation/feg_api.py @@ -31,7 +31,7 @@ class FegApi: - def __init__(self, endpoint: str, device_info: DeviceInfo, user_info: UserInfo): + def __init__(self, endpoint: str, device_info: DeviceInfo, user_info: UserInfo, allow_self_signed: bool = False): self.endpoint = endpoint self.device_info = device_info self.user_info = user_info @@ -46,6 +46,7 @@ def __init__(self, endpoint: str, device_info: DeviceInfo, user_info: UserInfo): HttpHeaderKey.DEVICE_INFO: device_qs, HttpHeaderKey.USER_INFO: user_qs, } + self.allow_self_signed = allow_self_signed def get_job(self, request: JobRequest) -> JobResponse: return self._do_post( @@ -91,7 +92,12 @@ def get_selection(self, request: SelectionRequest) -> SelectionResponse: ) def _do_post(self, clazz, url, params, body): - response = requests.post(url, params=params, json=body, headers=self.common_headers) + import urllib3 + if self.allow_self_signed: + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + response = requests.post( + url, params=params, json=body, headers=self.common_headers, verify=not self.allow_self_signed + ) code = response.status_code if code == 200: return clazz(**response.json()) diff --git a/nvflare/edge/simulation/run_device_simulator.py b/nvflare/edge/simulation/run_device_simulator.py index a1b517725f..ae904a0c36 100644 --- a/nvflare/edge/simulation/run_device_simulator.py +++ b/nvflare/edge/simulation/run_device_simulator.py @@ -63,6 +63,7 @@ def _send_request_to_proxy(request, device: SimulatedDevice, parser: ConfigParse endpoint=parser.get_endpoint(), device_info=device.get_device_info(), user_info=device.get_user_info(), + allow_self_signed=parser.allow_self_signed, ) if isinstance(request, TaskRequest): return api.get_task(request)