diff --git a/codexctl/__init__.py b/codexctl/__init__.py index f3f3002..230a1ae 100644 --- a/codexctl/__init__.py +++ b/codexctl/__init__.py @@ -98,7 +98,7 @@ def call_func(self, function: str, args: dict) -> None: args["out"] = os.getcwd() + "/extracted" logger.debug(f"Extracting {args['file']} to {args['out']}") - image, volume = get_update_image(args["file"]) + image, volume = get_update_image(args["file"], logger=logger) image.seek(0) with open(args["out"], "wb") as f: @@ -131,7 +131,7 @@ def call_func(self, function: str, args: dict) -> None: ) try: - image, volume = get_update_image(args["file"]) + image, volume = get_update_image(args["file"], logger=logger) inode = volume.inode_at(args["target_path"]) except FileNotFoundError: @@ -192,6 +192,9 @@ def call_func(self, function: str, args: dict) -> None: "Psutil is required for SSH access. Please install it." ) remote = True + else: + if function == "transfer": + raise SystemError("You can't transfer files alredy on your device!") from .device import DeviceManager from .server import get_available_version @@ -199,6 +202,7 @@ def call_func(self, function: str, args: dict) -> None: remarkable = DeviceManager( remote=remote, address=args["address"], + port=args["port"], logger=self.logger, authentication=args["password"], ) @@ -208,7 +212,11 @@ def call_func(self, function: str, args: dict) -> None: elif version == "toltec": version = self.updater.get_toltec_version(remarkable.hardware) - if function == "status": + if function == "transfer": + remarkable.transfer_file_to_remote(args["file"], args["destination"]) + print("Done!") + + elif function == "status": beta, prev, current, version_id = remarkable.get_device_status() print( f"\nCurrent version: {current}\nOld update engine: {prev}\nBeta active: {beta}\nVersion id: {version_id}" @@ -232,10 +240,11 @@ def call_func(self, function: str, args: dict) -> None: # Do we have a specific update file to serve? update_file = version if os.path.isfile(version) else None + manual_dd_update = False def version_lookup(version: str | None) -> re.Match[str] | None: return re.search(r"\b\d+\.\d+\.\d+\.\d+\b", cast(str, version)) - + version_number = version_lookup(version) if not version_number: @@ -260,22 +269,13 @@ def version_lookup(version: str | None) -> re.Match[str] | None: if device_version_uses_new_engine: if not update_file_requires_new_engine: - raise SystemError( - "Cannot downgrade to this version as it uses the old update engine, please manually downgrade." - ) - # TODO: Implement manual downgrading. + manual_dd_update = True # `codexctl download --out . 3.11.2.5` # `codexctl extract --out 3.11.2.5.img 3.11.2.5_reMarkable2-qLFGoqPtPL.signed` # `codexctl transfer 3.11.2.5.img ~/root` # `dd if=/home/root/3.11.2.5.img of=/dev/mmcblk2p2` (depending on fallback partition) # `codexctl restore` - else: - if update_file_requires_new_engine: - raise SystemError( - "This version requires the new update engine, please upgrade your device to version 3.11.2.5 first." - ) - ############################################################# if not update_file_requires_new_engine: @@ -318,7 +318,30 @@ def version_lookup(version: str | None) -> re.Match[str] | None: ) if device_version_uses_new_engine: - remarkable.install_sw_update(update_file) + if not manual_dd_update: + remarkable.install_sw_update(update_file) + else: + try: + from .analysis import get_update_image + except ImportError: + raise ImportError( + "remarkable_update_image is required for this update. Please install it!" + ) + + out_image_file = f"{version_number}.img" + + logger.debug(f"Extracting {update_file} to ./{out_image_file}") + image, volume = get_update_image(update_file, logger=logger) + image.seek(0) + + with open(out_image_file, "wb") as f: + f.write(image.read()) + + print("Now installing from .img") + + remarkable.install_manual_update(out_image_file) + + os.remove(out_image_file) else: remarkable.install_ohma_update(update_file) @@ -360,9 +383,24 @@ def main() -> None: help="Specify password or path to SSH key for remote access", dest="password", ) + parser.add_argument( + "--port", + required=False, + type=int, + default=22, + help="Specify specific SSH port, shouldn't be needed unless you've changed it." + ) subparsers = parser.add_subparsers(dest="command") subparsers.required = True # This fixes a bug with older versions of python + ### Transfer subcommand + transfer = subparsers.add_parser( + "transfer", + help="Transfer a file from your host to the device", + ) + transfer.add_argument("file", help="Location of file to transfer") + transfer.add_argument("destination", help="Where the file should be put on the device") + ### Install subcommand install = subparsers.add_parser( "install", @@ -382,7 +420,7 @@ def main() -> None: "-d", help="Hardware to download for", required=True, - dest="hardware", + dest="hardware" ) ### Backup subcommand diff --git a/codexctl/analysis.py b/codexctl/analysis.py index 8e8fe85..aa9bda1 100644 --- a/codexctl/analysis.py +++ b/codexctl/analysis.py @@ -1,12 +1,10 @@ import ext4 -import warnings -import errno from remarkable_update_image import UpdateImage from remarkable_update_image import UpdateImageSignatureException -def get_update_image(file: str): +def get_update_image(file: str, logger): """Extracts files from an update image (<3.11 currently)""" image = UpdateImage(file) @@ -20,14 +18,14 @@ def get_update_image(file: str): image.verify(inode.open().read()) except UpdateImageSignatureException: - warnings.warn("Signature doesn't match contents", RuntimeWarning) + logger.warning("Signature doesn't match contents", RuntimeWarning) except FileNotFoundError: - warnings.warn("Public key missing", RuntimeWarning) + logger.warning("Public key missing", RuntimeWarning) except OSError as e: if e.errno != errno.ENOTDIR: raise - warnings.warn("Unable to open public key", RuntimeWarning) + logger.warning("Unable to open public key", RuntimeWarning) return image, volume diff --git a/codexctl/device.py b/codexctl/device.py index 5730f4c..17cf7d4 100644 --- a/codexctl/device.py +++ b/codexctl/device.py @@ -17,18 +17,20 @@ class DeviceManager: def __init__( - self, logger=None, remote=False, address=None, authentication=None + self, logger=None, remote=False, address=None, authentication=None, port=22 ) -> None: """Initializes the DeviceManager for codexctl Args: remote (bool, optional): Whether the device is remote. Defaults to False. address (bool, optional): Known IP of remote device, if applicable. Defaults to None. + port (int, optional): Known port of remote device SSH service. Defaults to 22. logger (logger, optional): Logger object for logging. Defaults to None. Authentication (str, optional): Authentication method. Defaults to None. """ self.logger = logger self.address = address + self.port = port self.authentication = authentication self.client = None @@ -37,12 +39,9 @@ def __init__( if remote: self.client = self.connect_to_device( - authentication=authentication, remote_address=address + authentication=authentication, remote_address=address, port=port ) - self.client.authentication = authentication - self.client.address = address - ftp = self.client.open_sftp() with ftp.file("/sys/devices/soc0/machine") as file: machine_contents = file.read().decode("utf-8").strip("\n") @@ -110,42 +109,42 @@ def get_remarkable_address(self) -> str: str: IP address of the remarkable device """ - if self.check_is_address_reachable("10.11.99.1"): + if self.check_if_address_reachable("10.11.99.1", self.port): return "10.11.99.1" while True: remote_ip = input("Please enter the IP of the remarkable device: ") - if self.check_is_address_reachable(remote_ip): + if self.check_if_address_reachable(remote_ip, self.port): return remote_ip print(f"Error: Device {remote_ip} is not reachable. Please try again.") - def check_is_address_reachable(self, remote_ip="10.11.99.1") -> bool: + def check_if_address_reachable(self, remote_ip="10.11.99.1", remote_port=22) -> bool: """Checks if the given IP address is reachable over SSH Args: remote_ip (str, optional): IP to check. Defaults to '10.11.99.1'. - + remote_port (int, optional): Port to check. Defaults to `22`. Returns: bool: True if reachable, False otherwise """ self.logger.debug(f"Checking if {remote_ip} is reachable") + try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.settimeout(1) - sock.connect((remote_ip, 22)) + sock.connect((remote_ip, remote_port)) sock.shutdown(2) return True - - except Exception: + except FileNotFoundError: self.logger.debug(f"Device {remote_ip} is not reachable") return False def connect_to_device( - self, remote_address=None, authentication=None + self, remote_address=None, authentication=None, port=22 ) -> paramiko.client.SSHClient: """Connects to the device using the given IP address @@ -161,13 +160,13 @@ def connect_to_device( remote_address = self.get_remarkable_address() self.address = remote_address # For future reference else: - if self.check_is_address_reachable(remote_address) is False: + if self.check_if_address_reachable(remote_address, port) is False: raise SystemError(f"Error: Device {remote_address} is not reachable!") client = paramiko.client.SSHClient() client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - if authentication: + if authentication != None: self.logger.debug(f"Using authentication: {authentication}") try: if os.path.isfile(authentication): @@ -181,9 +180,16 @@ def connect_to_device( self.logger.debug( f"Attempting to connect to {remote_address} with password {authentication}" ) - client.connect( - remote_address, username="root", password=authentication - ) + + if authentication == " ": + transport = paramiko.transport.Transport((remote_address, port)) + transport.start_client() + transport.auth_none("root") + client._transport = transport + else: + client.connect( + remote_address, username="root", password=authentication, port=port + ) except paramiko.ssh_exception.AuthenticationException: print("Incorrect password or ssh path given in arguments!") @@ -425,6 +431,85 @@ def restore_previous_version(self) -> None: self.logger.debug("Restore script ran") + def transfer_file_to_remote(self, file_location: str, destination: str): + """ + Tranfers file at file_location to destination on devicec + """ + ftp_client = self.client.open_sftp() + + print(f"Uploading {file_location} to {destination}") + + ftp_client.put( + file_location, destination, callback=self.output_put_progress + ) + + def install_manual_update(self, version_file: str) -> None: + if self.client: + print(f"Uploading {version_file} image") + + out_image_file = f"/tmp/{out_image_file}" + + self.transfer_file_to_remote(version_file, destination=out_image_file) + + _stdin, stdout, _stderr = self.client.exec_command("/sbin/fw_printenv -n active_partition") + # TODO Before merge: Make this utilise the mount command instead + + fallback_partition = f"mmcblk2p{stdout}" + + print("Now running dd to overwrite the fallback partition") + + _stdin, stdout, _stderr = self.client.exec_command(f"dd if={version_file} of=/dev/{fallback_partition}") + + self.logger.debug( + f'Stdout of dd is {_stderr.readlines()}' + ) + + #### Now disable automatic updates + + self.client.exec_command("sleep 1 && reboot") # Should be enough + self.client.close() + + time.sleep( + 2 + ) # Somehow the code runs faster than the time it takes for the device to reboot + + print("Trying to connect to device") + + while not self.check_if_address_reachable(self.address, self.port): + time.sleep(1) + + self.client = self.connect_to_device( + remote_address=self.address, authentication=self.authentication, port=self.port + ) + self.client.exec_command("systemctl stop swupdate memfaultd") + + print( + "Update complete and update service disabled, restart device to enable it" + ) + + else: + stdout = subprocess.run(['/sbin/fw_printenv', '-n', 'active_partition'], stdout=subprocess.PIPE).stdout.decode().strip() + # TODO Before merge: Make this utilise the mount command instead + + fallback_partition = f"mmcblk2p{stdout}" + + print("Now running dd to overwrite the fallback partition") + + with subprocess.Popen( + f"dd if={version_file} of=/dev/{fallback_partition}", + text=True, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env={"PATH": "/bin:/usr/bin:/sbin"}, + ) as process: + self.logger.debug( + f'Stdout of update checking service is {"".join(process.stderr.readlines())}' + ) + + print("Update complete and device rebooting") + os.system("reboot") + def reboot_device(self) -> None: REBOOT_CODE = """ if systemctl is-active --quiet tarnish.service; then @@ -466,8 +551,6 @@ def install_sw_update(self, version_file: str) -> None: command = f'/usr/bin/swupdate -v -i VERSION_FILE -k /usr/share/swupdate/swupdate-payload-key-pub.pem -H "{self.hardware}:1.0" -e "stable,copy1"' if self.client: - ftp_client = self.client.open_sftp() - print(f"Uploading {version_file} image") out_location = f"/tmp/{os.path.basename(version_file)}.swu" @@ -510,11 +593,11 @@ def install_sw_update(self, version_file: str) -> None: print("Trying to connect to device") - while not self.check_is_address_reachable(self.address): + while not self.check_if_address_reachable(self.address, self.port): time.sleep(1) self.client = self.connect_to_device( - remote_address=self.address, authentication=self.authentication + remote_address=self.address, authentication=self.authentication, port=self.port ) self.client.exec_command("systemctl stop swupdate memfaultd") @@ -628,11 +711,11 @@ def install_ohma_update(self, version_available: dict) -> None: print("Trying to connect to device") - while not self.check_is_address_reachable(self.address): + while not self.check_if_address_reachable(self.address, self.port): time.sleep(1) self.client = self.connect_to_device( - remote_address=self.address, authentication=self.authentication + remote_address=self.address, authentication=self.authentication, port=self.port ) self.client.exec_command("systemctl stop update-engine") @@ -674,6 +757,6 @@ def output_put_progress(transferred: int, toBeTransferred: int) -> None: """Used for displaying progress for paramiko ftp.put function""" print( - f"Transferring progress{int((transferred / toBeTransferred) * 100)}%", + f"Transferring progress {int((transferred/toBeTransferred)*100)}%", end="\r", )