119 lines
3.6 KiB
Python
119 lines
3.6 KiB
Python
#!/usr/bin/env python3
|
||
|
||
import argparse
|
||
import sys
|
||
import numpy as np
|
||
import encoder
|
||
import decoder
|
||
import channel
|
||
import subprocess
|
||
import pathlib
|
||
import os
|
||
import tempfile
|
||
|
||
# Optional file paths used by receive_server() (handy when debugging).
# Both stay None unless the __main__ section overwrites them from the
# --input_file / --output_file command-line options; while None,
# receive_server() falls back to throw-away temporary files.
INPUT_FILE = None   # where the transmitted signal is written for client.py
OUTPUT_FILE = None  # where client.py is asked to write the channel output
|
||
|
||
def transmit(msg, C):
    """Encode ``msg`` with codebook ``C`` and return the encoder's output.

    This is a thin wrapper around ``encoder.encode_message``; both arguments
    are forwarded unchanged.
    """
    signal = encoder.encode_message(msg, C)
    return signal
|
||
|
||
def receive_local(c):
    """Pass the transmitted signal ``c`` through the in-process channel.

    Delegates directly to ``channel.channel`` and returns whatever it
    produces.
    """
    received = channel.channel(c)
    return received
|
||
|
||
def receive_server(c, hostname, port):
    """Send ``c`` through the remote channel by running ``client.py``.

    The signal is written to a text file with ``np.savetxt``, ``client.py``
    (located next to this script) is run as a subprocess with the input and
    output file names plus the server address, and the output file is read
    back with ``np.loadtxt``.

    If the module-level ``INPUT_FILE`` / ``OUTPUT_FILE`` are set, those paths
    are used and left on disk afterwards; otherwise temporary ``.txt`` files
    are created and removed before returning, including when any step fails.

    Args:
        c: Signal to transmit; anything ``np.savetxt`` accepts.
        hostname: Value passed to ``client.py`` as ``--srv_hostname``.
        port: Value passed to ``client.py`` as ``--srv_port`` (stringified).

    Returns:
        The array loaded from the file written by ``client.py``.

    Raises:
        subprocess.CalledProcessError: If ``client.py`` exits with a
            non-zero status.
    """
    in_name = out_name = None
    delete_in = delete_out = False
    try:
        # Input file: user-supplied path, or a temp file that we own.
        if INPUT_FILE:
            in_name = INPUT_FILE
        else:
            fd, in_name = tempfile.mkstemp(suffix='.txt')
            delete_in = True  # mark as ours before anything else can fail
            os.close(fd)      # np.savetxt reopens the file by name
        np.savetxt(in_name, c)

        # Output file: user-supplied path, or a temp file that we own.
        if OUTPUT_FILE:
            out_name = OUTPUT_FILE
        else:
            fd, out_name = tempfile.mkstemp(suffix='.txt')
            delete_out = True
            os.close(fd)

        # Run client.py with the same interpreter that runs this script.
        cmd = [
            sys.executable,
            str(pathlib.Path(__file__).parent / 'client.py'),
            '--input_file', in_name,
            '--output_file', out_name,
            '--srv_hostname', hostname,
            '--srv_port', str(port),
        ]
        subprocess.run(cmd, check=True)
        return np.loadtxt(out_name)
    finally:
        # Only remove files created here; user-supplied paths are kept.
        for path, owned in ((in_name, delete_in), (out_name, delete_out)):
            if owned:
                try:
                    os.remove(path)
                except FileNotFoundError:
                    # Already gone (e.g. removed by the client); nothing to do.
                    pass
|
||
|
||
def receive(c, mode, hostname, port):
    """Route the transmitted signal ``c`` through the selected channel.

    Args:
        c: Signal to send through the channel.
        mode: ``'local'`` to use ``receive_local`` or ``'server'`` to use
            ``receive_server``.
        hostname: Server host name; only used in ``'server'`` mode.
        port: Server port; only used in ``'server'`` mode.

    Returns:
        Whatever the selected receive function returns.

    Raises:
        ValueError: If ``mode`` is neither ``'local'`` nor ``'server'``.
    """
    if mode == 'local':
        return receive_local(c)
    if mode == 'server':
        return receive_server(c, hostname, port)
    raise ValueError("mode must be 'local' or 'server'")
|
||
|
||
def test_performance(msg, num_trials, mode, hostname, port):
|
||
if len(msg) != 40:
|
||
raise ValueError('Message must be exactly 40 characters.')
|
||
|
||
# build the same codebook
|
||
C = encoder.make_codebook(r=5, num_blocks=40)
|
||
|
||
successes = 0
|
||
print(f"Original message: {msg!r}")
|
||
print(f"Running {num_trials} trials over '{mode}' channel...\n")
|
||
|
||
for i in range(num_trials):
|
||
# TX → channel → RX
|
||
c = transmit(msg, C)
|
||
Y = receive(c, mode, hostname, port)
|
||
|
||
# decode with state‐estimation
|
||
est, s_est = decoder.decode_blocks_with_state(Y, C)
|
||
|
||
# count char errors
|
||
errors = sum(1 for a,b in zip(est, msg) if a!=b)
|
||
if errors == 0:
|
||
successes += 1
|
||
|
||
print(f"Trial {i+1:3d}: state={s_est} decoded={est!r} errors={errors}")
|
||
|
||
pct = 100 * successes / num_trials
|
||
print("\n=== Summary ===")
|
||
print(f" Total trials: {num_trials}")
|
||
print(f" Perfect decodings: {successes}")
|
||
print(f" Overall accuracy: {pct:.2f}%")
|
||
|
||
def parse_args():
    """Parse the command-line options of the test script.

    Returns:
        The ``argparse.Namespace`` with attributes ``message``, ``trials``,
        ``mode``, ``hostname``, ``port``, ``input_file`` and ``output_file``.
    """
    parser = argparse.ArgumentParser(description='Test comms system')

    # What to send and how many times.
    parser.add_argument(
        '--message', '-m',
        required=True,
        help='40-char message',
    )
    parser.add_argument(
        '--trials', '-n',
        type=int,
        default=200,
        help='Number of trials',
    )

    # Which channel to use and, for the server channel, where it lives.
    parser.add_argument('--mode', choices=['local', 'server'], default='local')
    parser.add_argument('--hostname', default='iscsrv72.epfl.ch')
    parser.add_argument('--port', type=int, default=80)

    # Optional file paths for the server-mode exchange.
    parser.add_argument(
        '--input_file', '-i',
        default=None,
        help='(server mode) where to write input.txt',
    )
    parser.add_argument(
        '--output_file', '-o',
        default=None,
        help='(server mode) where to write output.txt',
    )

    namespace = parser.parse_args()
    return namespace
|
||
|
||
if __name__ == '__main__':
    args = parse_args()

    # Publish the optional file paths as module globals; receive_server()
    # reads them to decide between user-supplied and temporary files.
    INPUT_FILE, OUTPUT_FILE = args.input_file, args.output_file

    test_performance(
        args.message,
        args.trials,
        args.mode,
        args.hostname,
        args.port,
    )
|