pdc_project/main.py
2025-05-27 12:04:14 +02:00

119 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
import argparse
import sys
import numpy as np
import encoder
import decoder
import channel
import subprocess
import pathlib
import os
import tempfile
# Optional debugging paths for the files exchanged with client.py.
# Both stay None unless the script entry point overrides them from the
# --input_file / --output_file command-line options; receive_server falls
# back to temporary files while they are unset.
INPUT_FILE = OUTPUT_FILE = None
def transmit(msg, C):
    """Encode ``msg`` with codebook ``C`` via ``encoder.encode_message``.

    Returns whatever the encoder produces for the message, unchanged.
    """
    codeword = encoder.encode_message(msg, C)
    return codeword
def receive_local(c):
    """Pass ``c`` through the in-process ``channel.channel`` function.

    Returns the channel's output, unchanged.
    """
    received = channel.channel(c)
    return received
def receive_server(c, hostname, port):
    """Send ``c`` through the remote channel by running ``client.py``.

    The signal is written with ``np.savetxt`` to ``INPUT_FILE`` when that
    global is set, otherwise to a temporary file. ``client.py`` (located next
    to this file) is then run as a subprocess with the input/output paths and
    the server address, and its output file is read back with ``np.loadtxt``.

    Args:
        c: Array-like signal to transmit (anything ``np.savetxt`` accepts).
        hostname: Server host name passed to ``client.py`` as
            ``--srv_hostname``.
        port: Server port passed to ``client.py`` as ``--srv_port``.

    Returns:
        The array loaded from the output file written by ``client.py``.

    Raises:
        subprocess.CalledProcessError: If ``client.py`` exits with a non-zero
            status.
    """
    in_name = out_name = None
    delete_in = delete_out = False
    try:
        # Input file: user-supplied debug path, or a temp file we own.
        if INPUT_FILE:
            in_name = INPUT_FILE
        else:
            fd, in_name = tempfile.mkstemp(suffix='.txt')
            os.close(fd)
            delete_in = True
        np.savetxt(in_name, c)

        # Output file: user-supplied debug path, or a temp file we own.
        if OUTPUT_FILE:
            out_name = OUTPUT_FILE
        else:
            fd, out_name = tempfile.mkstemp(suffix='.txt')
            os.close(fd)
            delete_out = True

        cmd = [
            sys.executable,
            str(pathlib.Path(__file__).parent / 'client.py'),
            '--input_file', in_name,
            '--output_file', out_name,
            '--srv_hostname', hostname,
            '--srv_port', str(port),
        ]
        subprocess.run(cmd, check=True)
        return np.loadtxt(out_name)
    finally:
        # All setup happens inside the try so that temporary files are removed
        # even when writing the input or creating the output file fails.
        for path, is_temp in ((in_name, delete_in), (out_name, delete_out)):
            if not is_temp:
                continue
            try:
                os.remove(path)
            except FileNotFoundError:
                # Already gone (e.g. removed by the client); nothing to do.
                pass
def receive(c, mode, hostname, port):
    """Send ``c`` through the channel selected by ``mode``.

    Args:
        c: Signal to transmit.
        mode: ``'local'`` to use ``receive_local`` or ``'server'`` to use
            ``receive_server``.
        hostname: Server host name (used only in ``'server'`` mode).
        port: Server port (used only in ``'server'`` mode).

    Returns:
        The output of the selected channel function.

    Raises:
        ValueError: If ``mode`` is neither ``'local'`` nor ``'server'``.
    """
    if mode == 'server':
        return receive_server(c, hostname, port)
    if mode == 'local':
        return receive_local(c)
    raise ValueError("mode must be 'local' or 'server'")
def test_performance(msg, num_trials, mode, hostname, port):
    """Run repeated transmit/receive/decode trials and print a summary.

    A codebook is built once with ``encoder.make_codebook(r=5,
    num_blocks=40)``. Each trial encodes ``msg``, sends it through
    ``receive`` and decodes the result with
    ``decoder.decode_blocks_with_state``. A trial counts as a success only
    when the decoded text has no character errors.

    Args:
        msg: Message to transmit; must be exactly 40 characters long.
        num_trials: Number of trials to run.
        mode: Channel selector forwarded to ``receive``.
        hostname: Server host name forwarded to ``receive``.
        port: Server port forwarded to ``receive``.

    Raises:
        ValueError: If ``msg`` is not exactly 40 characters long.
    """
    if len(msg) != 40:
        raise ValueError('Message must be exactly 40 characters.')

    # build the same codebook
    C = encoder.make_codebook(r=5, num_blocks=40)

    successes = 0
    print(f"Original message: {msg!r}")
    print(f"Running {num_trials} trials over '{mode}' channel...\n")

    for trial in range(1, num_trials + 1):
        # TX → channel → RX
        c = transmit(msg, C)
        Y = receive(c, mode, hostname, port)

        # decode with state estimation
        est, s_est = decoder.decode_blocks_with_state(Y, C)

        # Count character errors. zip() stops at the shorter sequence, so any
        # length difference is added explicitly; otherwise a truncated decode
        # could be scored as perfect.
        # NOTE(review): assumes ``est`` is a sized sequence such as str.
        errors = sum(1 for a, b in zip(est, msg) if a != b)
        errors += abs(len(est) - len(msg))
        if errors == 0:
            successes += 1
        print(f"Trial {trial:3d}: state={s_est} decoded={est!r} errors={errors}")

    # Guard the division so that zero trials reports 0% instead of raising
    # ZeroDivisionError.
    pct = 100 * successes / num_trials if num_trials > 0 else 0.0
    print("\n=== Summary ===")
    print(f" Total trials: {num_trials}")
    print(f" Perfect decodings: {successes}")
    print(f" Overall accuracy: {pct:.2f}%")
def parse_args():
    """Parse the command-line options for the test script.

    Returns:
        The ``argparse.Namespace`` with ``message``, ``trials``, ``mode``,
        ``hostname``, ``port``, ``input_file`` and ``output_file``.
    """
    parser = argparse.ArgumentParser(description='Test comms system')

    # (flags, keyword settings) for every option, in the order they are shown
    # in the help output.
    option_specs = (
        (('--message', '-m'),
         dict(required=True, help='40-char message')),
        (('--trials', '-n'),
         dict(type=int, default=200, help='Number of trials')),
        (('--mode',),
         dict(choices=['local', 'server'], default='local')),
        (('--hostname',),
         dict(default='iscsrv72.epfl.ch')),
        (('--port',),
         dict(type=int, default=80)),
        (('--input_file', '-i'),
         dict(default=None, help='(server mode) where to write input.txt')),
        (('--output_file', '-o'),
         dict(default=None, help='(server mode) where to write output.txt')),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()
    # Publish the optional debug paths as module globals before any trial
    # runs, so that receive_server picks them up.
    INPUT_FILE, OUTPUT_FILE = args.input_file, args.output_file
    test_performance(
        msg=args.message,
        num_trials=args.trials,
        mode=args.mode,
        hostname=args.hostname,
        port=args.port,
    )