From 094d90715e5073d8dcb35c5c9f600153aa8e4456 Mon Sep 17 00:00:00 2001 From: appellet Date: Tue, 27 May 2025 11:23:30 +0200 Subject: [PATCH] fix: main --- main.py | 144 ++++++++++++++++++++++++++++++------------- performance_local.py | 37 ----------- test_local.py | 36 ----------- test_server.py | 56 ----------------- 4 files changed, 100 insertions(+), 173 deletions(-) delete mode 100644 performance_local.py delete mode 100644 test_local.py delete mode 100644 test_server.py diff --git a/main.py b/main.py index d0a33d7..783607c 100644 --- a/main.py +++ b/main.py @@ -1,58 +1,114 @@ -# main.py #!/usr/bin/env python3 + import argparse -import socket +import sys import numpy as np -import channel_helper as ch -from encoder import make_codebook, encode_message -from decoder import decode_blocks, count_errors -from channel import channel as external_channel +import encoder +import decoder +import channel +import subprocess +import tempfile +import pathlib +import os -def send_and_recv(x: np.ndarray, host: str, port: int) -> np.ndarray: - """Send samples x to server and receive output via TCP""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.connect((host, port)) - header = b'0' + b'dUV' - ch.send_msg(sock, header, x) - _, Y = ch.recv_msg(sock) +def transmit(msg, C): + """ + Transmitter: encodes the message into real-valued samples using the codebook C. + """ + return encoder.encode_message(msg, C) + + +def receive_local(c): + """ + Sends the samples through the local channel simulation. + """ + return channel.channel(c) + + +def receive_server(c, hostname, port): + """ + Sends the samples to the remote server via client.py and retrieves the output. + """ + # Write input samples to a temporary file + with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as in_f: + np.savetxt(in_f.name, c) + in_name = in_f.name + # Prepare output file + out_fd, out_name = tempfile.mkstemp(suffix='.txt') + os.close(out_fd) + # Invoke client.py + cmd = [ + sys.executable, + str(pathlib.Path(__file__).parent / 'client.py'), + '--input_file', in_name, + '--output_file', out_name, + '--srv_hostname', hostname, + '--srv_port', str(port) + ] + try: + subprocess.run(cmd, check=True) + Y = np.loadtxt(out_name) + finally: + # Clean up temp files + os.remove(in_name) + os.remove(out_name) return Y -def main(): - p = argparse.ArgumentParser(description="PDC Tx/Rx local or server") - p.add_argument("--message", required=True, help="40-character message to send") - p.add_argument("--srv_hostname", help="Server hostname") - p.add_argument("--srv_port", type=int, help="Server port") - p.add_argument("--local", action='store_true', help="Use local channel simulation") - args = p.parse_args() +def receive(c, mode, hostname, port): + """ + Wrapper to choose local or server channel. + """ + if mode == 'local': + return receive_local(c) + elif mode == 'server': + return receive_server(c, hostname, port) + else: + raise ValueError("Mode must be 'local' or 'server'") - msg = args.message + +def test_performance(msg, num_trials, mode, hostname, port): + """ + Runs num_trials transmissions of msg through the specified channel and reports accuracy. + """ if len(msg) != 40: raise ValueError("Message must be exactly 40 characters.") - num_blocks = len(msg) - C = make_codebook(r=6, num_blocks=num_blocks) - x = encode_message(msg, C) - print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}") + # Build codebook for 64 symbols, 40 blocks + C = encoder.make_codebook(r=5, num_blocks=40) + successes = 0 + for i in range(num_trials): + # Transmit + c = transmit(msg, C) + # Channel + Y = receive(c, mode, hostname, port) + # Decode + est = decoder.decode_blocks(Y, C) + if est == msg: + successes += 1 + pct = successes / num_trials * 100 + # Display results + print(f"Message: {msg}") + print(f"Trials: {num_trials}") + print(f"Mode: {mode}") + print(f"Correct decodings: {successes}") + print(f"Accuracy: {pct:.2f}%") - if args.local: - print("-- Local simulation mode --") - Y = external_channel(x) - else: - if not args.srv_hostname or not args.srv_port: - raise ValueError("Must specify --srv_hostname and --srv_port unless --local") - Y = send_and_recv(x, args.srv_hostname, args.srv_port) - msg_hat = decode_blocks(Y, C) - print(f"↓ Decoded message: {msg_hat}") +def parse_args(): + parser = argparse.ArgumentParser(description="Test communication system performance.") + parser.add_argument('--message', '-m', type=str, required=True, + help="40-character message to send.") + parser.add_argument('--trials', '-n', type=int, default=1, + help="Number of trials.") + parser.add_argument('--mode', choices=['local','server'], default='local', + help="Channel mode: 'local' or 'server'.") + parser.add_argument('--hostname', type=str, default='iscsrv72.epfl.ch', + help="Server hostname for server mode.") + parser.add_argument('--port', type=int, default=80, + help="Server port for server mode.") + return parser.parse_args() - errors = count_errors(msg, msg_hat) - print(f"Errors: {len(errors)} / {len(msg)} characters differ") - if errors: - for i, o, e in errors: - print(f" Pos {i}: sent '{o}' but got '{e}'") - else: - print("✔️ No decoding errors!") - -if __name__ == "__main__": - main() \ No newline at end of file +if __name__ == '__main__': + args = parse_args() + test_performance(args.message, args.trials, args.mode, args.hostname, args.port) diff --git a/performance_local.py b/performance_local.py deleted file mode 100644 index 8bac46f..0000000 --- a/performance_local.py +++ /dev/null @@ -1,37 +0,0 @@ -# performance_local.py -#!/usr/bin/env python3 -import argparse -import numpy as np -import random -from encoder import make_codebook, encode_message, ALPHABET -from decoder import decode_blocks, count_errors -from channel import channel - - -def random_message(length): - return ''.join(random.choice(ALPHABET) for _ in range(length)) - - -def main(): - parser = argparse.ArgumentParser(description="Monte Carlo evaluation over local channel") - parser.add_argument("--num", type=int, required=True, help="Number of trials") - parser.add_argument("--r", type=int, default=5, help="Hadamard order (default 5)") - args = parser.parse_args() - - num_trials = args.num - successes = 0 - - for _ in range(num_trials): - msg = random_message(40) - C = make_codebook(r=args.r, num_blocks=len(msg)) - x = encode_message(msg, C) - Y = channel(x) - msg_hat = decode_blocks(Y, C) - if msg_hat == msg: - successes += 1 - - ratio = successes / num_trials - print(f"Correctly decoded messages: {successes}/{num_trials} ({ratio:.2%})") - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/test_local.py b/test_local.py deleted file mode 100644 index 4495240..0000000 --- a/test_local.py +++ /dev/null @@ -1,36 +0,0 @@ -# test_local.py -#!/usr/bin/env python3 -import argparse -import numpy as np -from encoder import make_codebook, encode_message -from decoder import decode_blocks, count_errors -from channel import channel - - -def main(): - parser = argparse.ArgumentParser(description="Local test using channel.py") - parser.add_argument("--message", required=True, help="40-character message") - args = parser.parse_args() - - msg = args.message - if len(msg) != 40: - raise ValueError("Message must be exactly 40 characters.") - C = make_codebook(r=5, num_blocks=len(msg)) - x = encode_message(msg, C) - print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}") - - Y = channel(x) - - msg_hat = decode_blocks(Y, C) - print(f"↓ Decoded message: {msg_hat}") - - errors = count_errors(msg, msg_hat) - print(f"Errors: {len(errors)} / {len(msg)} characters differ") - if errors: - for i, o, e in errors: - print(f" Pos {i}: sent '{o}' but got '{e}'") - else: - print("✔️ No decoding errors!") - -if __name__ == "__main__": - main() diff --git a/test_server.py b/test_server.py deleted file mode 100644 index 9201646..0000000 --- a/test_server.py +++ /dev/null @@ -1,56 +0,0 @@ -# test_server.py -#!/usr/bin/env python3 -import argparse -import subprocess -import numpy as np -from encoder import make_codebook, encode_message -from decoder import decode_blocks, count_errors - - -def call_client(input_path, output_path, host, port): - subprocess.run([ - "python3", "client.py", - f"--input_file={input_path}", - f"--output_file={output_path}", - f"--srv_hostname={host}", - f"--srv_port={port}" - ], check=True) - - -def main(): - parser = argparse.ArgumentParser(description="Server test using client.py") - parser.add_argument("--message", required=True, help="40-character message to send") - parser.add_argument("--srv_hostname", default="iscsrv72.epfl.ch", help="Server hostname") - parser.add_argument("--srv_port", type=int, default=80, help="Server port") - args = parser.parse_args() - - msg = args.message - if len(msg) != 40: - raise ValueError("Message must be exactly 40 characters.") - C = make_codebook(r=5, num_blocks=len(msg)) - x = encode_message(msg, C) - - # write encoded symbols to fixed input.txt - input_file = "input.txt" - output_file = "output.txt" - np.savetxt(input_file, x) - - # run client.py to read input.txt and write output.txt - call_client(input_file, output_file, args.srv_hostname, args.srv_port) - - # read received samples - Y = np.loadtxt(output_file) - - msg_hat = decode_blocks(Y, C) - print(f"↓ Decoded message: {msg_hat}") - - errors = count_errors(msg, msg_hat) - print(f"Errors: {len(errors)} / {len(msg)} characters differ") - if errors: - for i, o, e in errors: - print(f" Pos {i}: sent '{o}' but got '{e}'") - else: - print("✔️ No decoding errors!") - -if __name__ == "__main__": - main() \ No newline at end of file