pdc_project/channel_helper.py
2025-05-19 19:44:29 +02:00

97 lines
2.4 KiB
Python

# ############################################################################
# channel_helper.py for PDC 2025 (DO NOT EDIT!!)
# =========
# Author : Sepand KASHANI [sepand.kashani@epfl.ch]
# ############################################################################
import struct
import numpy as np
import io
def send_msg(sock, header, data):
    """
    Transmit one length-prefixed packet through a socket.

    Wire layout: a 4-byte big-endian unsigned length, followed by the
    4-byte `header`, followed by `data` serialised in NumPy ``.npy``
    format.  The length field counts the header plus the serialised array,
    not the length field itself.

    Parameters
    ----------
    sock : :py:class:`~socket.socket`
        Object whose ``sendall`` method receives the assembled packet.
    header : bytes
        (4,) byte string.
    data : :py:class:`~numpy.ndarray`
        Array written with :py:func:`numpy.save`.

    Raises
    ------
    ValueError
        If `header` does not have length 4.
    """
    header_size = len(header)
    if header_size != 4:
        raise ValueError('Parameter[header]: expected byte() of length 4.')

    # Serialise the array into an in-memory buffer to learn its byte size.
    with io.BytesIO() as stream:
        np.save(stream, data)
        payload = stream.getvalue()

    # Length prefix covers header + payload.
    length_prefix = struct.pack('>I', header_size + len(payload))
    packet = length_prefix + header + payload
    sock.sendall(packet)
def recv_msg(sock, N_byte_max=None):
    """
    Receive a packet from the network.

    The packet is read as a 4-byte big-endian unsigned length, followed by
    that many bytes: the first 4 are returned as the header and the rest is
    decoded with :py:func:`numpy.load`.

    Parameters
    ----------
    sock : :py:class:`~socket.socket`
    N_byte_max : int
        Maximum number of bytes to accept. (None = unlimited.)
        Compared against the announced length (header + payload); the
        4-byte length prefix itself is not counted.
        :py:class:`RuntimeError` raised if threshold exceeded.

    Returns
    -------
    header : bytes
        (4,) byte string
    data : :py:class:`~numpy.ndarray`

    Raises
    ------
    TypeError
        If `N_byte_max` is given but is not positive.
    RuntimeError
        If the announced length exceeds `N_byte_max`.  The payload is not
        read from the socket in that case.
    """
    if N_byte_max is None:
        N_byte_max = np.inf
    elif not (N_byte_max > 0):
        raise TypeError('Parameter[N_byte_max] must be positive.')

    # Extract message length
    N_msg_raw = recv_bytes(sock, 4)
    N_msg = struct.unpack('>I', N_msg_raw)[0]  # bytes
    if N_msg > N_byte_max:
        # getpeername() returns (host, port) for IPv4, a 4-tuple for IPv6
        # and a plain string for Unix sockets, so do not assume a 2-tuple.
        peer = sock.getpeername()
        if isinstance(peer, tuple) and len(peer) >= 2:
            s_name = f'{peer[0]}:{peer[1]}'
        else:
            s_name = str(peer)
        raise RuntimeError(f'{s_name} sends {N_msg:>-#9_d} bytes, but N_byte_max={N_byte_max:>-#9_d}.')

    msg = recv_bytes(sock, N_msg)
    header = msg[:4]
    with io.BytesIO(msg[4:]) as f:
        # Data comes from the network: never unpickle object arrays.
        data = np.load(f, allow_pickle=False)
    return header, data
def recv_bytes(sock, N_byte):
    """
    Receive bytes from the network.

    Calls ``sock.recv`` repeatedly, asking for at most 4096 bytes at a
    time, until exactly `N_byte` bytes have been collected.

    Parameters
    ----------
    sock : :py:class:`~socket.socket`
    N_byte : int
        Number of bytes to read.

    Returns
    -------
    byte_data : bytes
        (N_byte,)

    Raises
    ------
    ConnectionError
        If ``sock.recv`` returns no data before `N_byte` bytes were read
        (the peer closed the connection).
    """
    packet_size = 2 ** 12
    packets, N_byte_read = [], 0
    while N_byte_read < N_byte:
        packet = sock.recv(min(packet_size, N_byte - N_byte_read))
        if not packet:
            # An empty result means the peer closed the connection; without
            # this check the loop would never make progress and spin forever.
            raise ConnectionError(
                f'Connection closed after {N_byte_read} of {N_byte} bytes.'
            )
        packets.append(packet)
        N_byte_read += len(packet)
    return b''.join(packets)