diff --git a/cryptos/network.py b/cryptos/network.py index 54d7f3b..aa7829f 100644 --- a/cryptos/network.py +++ b/cryptos/network.py @@ -21,6 +21,7 @@ 'test': b'\x0b\x11\x09\x07', } + @dataclass class NetworkEnvelope: command: bytes @@ -58,11 +59,11 @@ def encode(self): # encode the command assert len(self.command) <= 12 out += [self.command] - out += [b'\x00' * (12 - len(self.command))] # command padding + out += [b'\x00' * (12 - len(self.command))] # command padding # encode the payload - assert len(self.payload) <= 2**32 # in practice reference client nodes will reject >= 32MB... - out += [len(self.payload).to_bytes(4, 'little')] # payload length - out += [sha256(sha256(self.payload))[:4]] # checksum + assert len(self.payload) <= 2**32 # in practice reference client nodes will reject >= 32MB... + out += [len(self.payload).to_bytes(4, 'little')] # payload length + out += [sha256(sha256(self.payload))[:4]] # checksum out += [self.payload] return b''.join(out) @@ -75,6 +76,7 @@ def stream(self): # Specific types of commands and their payload encoder/decords follow # ----------------------------------------------------------------------------- + @dataclass class NetAddrStruct: """ @@ -82,7 +84,7 @@ class NetAddrStruct: currently assumes IPv4 address """ services: int = 0 - ip: bytes = b'\x00\x00\x00\x00' # IPv4 address + ip: bytes = b'\x00\x00\x00\x00' # IPv4 address port: int = 8333 def encode(self): @@ -107,9 +109,9 @@ class VersionMessage: """ # header information - version: int = 70015 # specifies what messages may be communicated - services: int = 0 # info about what capabilities are available - timestamp: int = None # 8 bytes Unix timestamp in little-endian + version: int = 70015 # specifies what messages may be communicated + services: int = 0 # info about what capabilities are available + timestamp: int = None # 8 bytes Unix timestamp in little-endian # receiver net_addr receiver: NetAddrStruct = field(default_factory=NetAddrStruct) # sender net_addr @@ -119,10 +121,10 @@ class VersionMessage: uint64_t Node random nonce, randomly generated every time a version packet is sent. This nonce is used to detect connections to self. """ - nonce: bytes = None # 8 bytes of nonce - user_agent: bytes = None # var_str: User Agent - latest_block: int = 0 # "The last block received by the emitting node" - relay: bool = False # Whether the remote peer should announce relayed transactions or not, see BIP 0037 + nonce: bytes = None # 8 bytes of nonce + user_agent: bytes = None # var_str: User Agent + latest_block: int = 0 # "The last block received by the emitting node" + relay: bool = False # Whether the remote peer should announce relayed transactions or not, see BIP 0037 command: str = field(init=False, default=b'version') @classmethod @@ -158,6 +160,7 @@ def encode(self): return b''.join(out) + @dataclass class VerAckMessage: """ @@ -174,6 +177,7 @@ def decode(cls, s): def encode(self): return b'' + @dataclass class PingMessage: """ @@ -193,6 +197,7 @@ def decode(cls, s): def encode(self): return self.nonce + @dataclass class PongMessage: """ @@ -212,15 +217,16 @@ def decode(cls, s): def encode(self): return self.nonce + @dataclass class GetHeadersMessage: """ https://en.bitcoin.it/wiki/Protocol_documentation#getheaders """ - version: int = 70015 # uint32_t protocol version - num_hashes: int = 1 # var_int, number of block locator hash entries; can be >1 if there is a chain split - start_block: bytes = None # char[32] block locator object - end_block: bytes = None # char[32] hash of the last desired block header; set to zero to get as many blocks as possible + version: int = 70015 # uint32_t protocol version + num_hashes: int = 1 # var_int, number of block locator hash entries; can be >1 if there is a chain split + start_block: bytes = None # char[32] block locator object + end_block: bytes = None # char[32] hash of the last desired block header; set to zero to get as many blocks as possible command: str = field(init=False, default=b'getheaders') def __post_init__(self): @@ -232,10 +238,11 @@ def encode(self): out = [] out += [self.version.to_bytes(4, 'little')] out += [encode_varint(self.num_hashes)] - out += [self.start_block[::-1]] # little-endian - out += [self.end_block[::-1]] # little-endian + out += [self.start_block[::-1]] # little-endian + out += [self.end_block[::-1]] # little-endian return b''.join(out) + @dataclass class HeadersMessage: """ @@ -266,17 +273,20 @@ def decode(cls, s): # A super lightweight baby node follows # ----------------------------------------------------------------------------- + class SimpleNode: - def __init__(self, host: str, net: str, verbose: int = 0): + def __init__(self, net: str, verbose: int = 0): self.net = net self.verbose = verbose - port = {'main': 8333, 'test': 18333}[net] + self.port = {'main': 8333, 'test': 18333}[net] self.socket = socket.socket() - self.socket.connect((host, port)) self.stream = self.socket.makefile('rb', None) + def connect(self, host): + self.socket.connect((host, self.port)) + def send(self, message): env = NetworkEnvelope(message.command, message.encode(), net=self.net) if self.verbose: @@ -289,9 +299,40 @@ def read(self): print(f"receiving: {env}") return env + def listen(self, *message_classes): + command = None + command_to_class = {m.command: m for m in message_classes} + host = socket.gethostbyname(socket.gethostname()) + + # Listener socket + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + # Bind and listen for any connections + s.bind((host, self.port)) + s.listen() + + while command not in message_classes: + # Accept new connection + connection, client_address = s.accept() + + # Creeate new stream from the new connection + self.stream = connection.makefile("rb") + + with connection: + if self.verbose: + print("Connection received by " + client_address[0]+"\n") + + env = self.read() + command = env.command + + if env.command in command_to_class: + return command_to_class[command].decode(env.stream()) + + return None + def wait_for(self, *message_classes): command = None - command_to_class = { m.command: m for m in message_classes } + command_to_class = {m.command: m for m in message_classes} # loop until one of the desired commands is encountered while command not in command_to_class: diff --git a/tests/test_network.py b/tests/test_network.py index 8e238f5..cb0d040 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -3,7 +3,7 @@ """ from io import BytesIO -from cryptos.network import NetworkEnvelope +from cryptos.network import NetworkEnvelope, PingMessage from cryptos.network import ( VersionMessage, GetHeadersMessage, @@ -12,6 +12,10 @@ from cryptos.network import SimpleNode from cryptos.block import Block +from multiprocessing import Process +import socket + + def test_encode_decode_network_envelope(): msg = bytes.fromhex('f9beb4d976657261636b000000000000000000005df6e0e2') @@ -28,6 +32,7 @@ def test_encode_decode_network_envelope(): assert envelope.payload == msg[24:] assert envelope.encode() == msg + def test_encode_version_payload(): m = VersionMessage( @@ -38,6 +43,7 @@ def test_encode_version_payload(): assert m.encode().hex() == '7f11010000000000000000000000000000000000000000000000000000000000000000000000ffff00000000208d000000000000000000000000000000000000ffff00000000208d0000000000000000182f70726f6772616d6d696e67626974636f696e3a302e312f0000000000' + def test_encode_getheaders_payload(): block_hex = '0000000000000000001237f46acddf58578a37e213d2a6edc4884a2fcad05ba3' m = GetHeadersMessage( @@ -45,6 +51,7 @@ def test_encode_getheaders_payload(): ) assert m.encode().hex() == '7f11010001a35bd0ca2f4a88c4eda6d213e2378a5758dfcd6af437120000000000000000000000000000000000000000000000000000000000000000000000000000000000' + def test_decode_headers_payload(): hex_msg = '0200000020df3b053dc46f162a9b00c7f0d5124e2676d47bbe7c5d0793a500000000000000ef445fef2ed495c275892206ca533e7411907971013ab83e3b47bd0d692d14d4dc7c835b67d8001ac157e670000000002030eb2540c41025690160a1014c577061596e32e426b712c7ca00000000000000768b89f07044e6130ead292a3f51951adbd2202df447d98789339937fd006bd44880835b67d8001ade09204600' s = BytesIO(bytes.fromhex(hex_msg)) @@ -53,11 +60,38 @@ def test_decode_headers_payload(): for b in headers.blocks: assert isinstance(b, Block) + def test_handshake(): node = SimpleNode( - host='testnet.programmingbitcoin.com', - net='test', + net='test' ) + node.connect("testnet.programmingbitcoin.com") node.handshake() node.close() + + +def listen_node(): + n = SimpleNode(net="test", verbose=1) + message = n.listen(PingMessage) + + assert message.command == b"ping" + assert message.nonce == b"ping" + + n.close() + + +def test_listen(): + p = Process(target=listen_node) + p.start() + + node = SimpleNode( + net='test', + ) + + host = socket.gethostbyname(socket.gethostname()) + node.connect(host) + node.send(PingMessage(b"ping")) + + p.join() + node.close()