pdc_project/main.py
2025-05-27 11:38:28 +02:00

142 lines
4.6 KiB
Python

#!/usr/bin/env python3
import argparse
import sys
import numpy as np
import encoder
import decoder
import channel
import subprocess
import pathlib
import os
import tempfile
# Global paths for debugging. Both start unset; the __main__ block fills them
# from --input_file / --output_file. When set, receive_server writes/reads the
# server samples at these paths and keeps the files instead of using temp files.
INPUT_FILE = OUTPUT_FILE = None
def transmit(msg, C):
    '''
    Transmitter: turn ``msg`` into real-valued channel samples using codebook ``C``.

    Thin wrapper that delegates to ``encoder.encode_message`` and hands back
    its result unchanged.
    '''
    samples = encoder.encode_message(msg, C)
    return samples
def receive_local(c):
    '''
    Push the samples ``c`` through the local channel simulation.

    Delegates to ``channel.channel`` and returns whatever it produces.
    '''
    channel_output = channel.channel(c)
    return channel_output
def _temp_txt_path():
    '''
    Create an empty temporary ``.txt`` file and return its path.

    The OS handle returned by ``mkstemp`` is closed immediately so the file can
    be reopened by name (by numpy here, by client.py in a child process). Unlike
    reopening a still-open ``NamedTemporaryFile`` by name, this also works on
    Windows.
    '''
    fd, path = tempfile.mkstemp(suffix='.txt')
    os.close(fd)
    return path


def receive_server(c, hostname, port):
    '''
    Send the samples to the remote server via client.py and return its reply.

    The samples ``c`` are written with ``np.savetxt`` to an input file. client.py
    (located next to this file) is then run as a subprocess with the input file,
    an output file and the server address, and the output file is read back
    with ``np.loadtxt``.

    If the module-level ``INPUT_FILE`` and/or ``OUTPUT_FILE`` are set, those
    paths are used and preserved afterwards. Otherwise temporary files are
    created and always removed, even when a step fails.

    Args:
        c: samples to transmit; anything ``np.savetxt`` accepts.
        hostname: server hostname, passed to client.py as ``--srv_hostname``.
        port: server port, passed to client.py as ``--srv_port``.

    Returns:
        The array loaded from client.py's output file.

    Raises:
        subprocess.CalledProcessError: if client.py exits with a non-zero status.
    '''
    # Only files we create ourselves are deleted; user-supplied paths are kept.
    delete_in = not INPUT_FILE
    delete_out = not OUTPUT_FILE
    in_name = out_name = None
    try:
        # Paths are created inside the try so that a failure in any later step
        # (savetxt, creating the second temp file, client.py) still cleans up.
        in_name = INPUT_FILE or _temp_txt_path()
        out_name = OUTPUT_FILE or _temp_txt_path()
        np.savetxt(in_name, c)
        cmd = [
            sys.executable,
            str(pathlib.Path(__file__).parent / 'client.py'),
            '--input_file', in_name,
            '--output_file', out_name,
            '--srv_hostname', str(hostname),
            '--srv_port', str(port),
        ]
        subprocess.run(cmd, check=True)
        return np.loadtxt(out_name)
    finally:
        for name, delete in ((in_name, delete_in), (out_name, delete_out)):
            if delete and name is not None and os.path.exists(name):
                os.remove(name)
def receive(c, mode, hostname, port):
    '''
    Route the samples ``c`` through the channel selected by ``mode``.

    ``'local'`` uses the local simulation (``hostname``/``port`` are ignored);
    ``'server'`` goes through the remote server at ``hostname:port``.

    Raises:
        ValueError: if ``mode`` is neither ``'local'`` nor ``'server'``.
    '''
    if mode == 'local':
        return receive_local(c)
    if mode == 'server':
        return receive_server(c, hostname, port)
    raise ValueError("Mode must be 'local' or 'server'")
def _count_char_errors(est, msg):
    '''
    Return the number of character positions where ``est`` differs from ``msg``.

    Positions present in only one of the two sequences (a truncated or
    over-long decode) also count as errors. Plain ``zip`` would silently drop
    them and could report a short decode as perfect.
    Assumes ``est`` supports ``len`` like a string — TODO confirm against
    ``decoder.decode_blocks``.
    '''
    mismatches = sum(1 for a, b in zip(est, msg) if a != b)
    return mismatches + abs(len(est) - len(msg))


def test_performance(msg, num_trials, mode, hostname, port):
    '''
    Run ``num_trials`` transmissions of ``msg`` through the specified channel
    and report per-trial results plus overall accuracy.

    A codebook is built once up front. Each trial encodes ``msg``, passes the
    samples through ``receive`` and decodes the result. It then prints the
    decoded text and its character-error count. A trial counts as a success
    only when it has zero errors, including no missing or extra characters.
    A summary with the number of perfect decodings and the accuracy is printed
    at the end.

    Args:
        msg: message to send; must be exactly 40 characters.
        num_trials: number of transmissions; must be at least 1.
        mode: ``'local'`` or ``'server'`` (see ``receive``).
        hostname: server hostname, only used in ``'server'`` mode.
        port: server port, only used in ``'server'`` mode.

    Raises:
        ValueError: if ``msg`` is not 40 characters long, if ``num_trials`` is
            less than 1, or (raised by ``receive``) if ``mode`` is invalid.
    '''
    if len(msg) != 40:
        raise ValueError('Message must be exactly 40 characters.')
    if num_trials < 1:
        # Without this, 0 trials ends in ZeroDivisionError after all output
        # has been printed, and negative counts report "-0.00%".
        raise ValueError('Number of trials must be at least 1.')
    # Build codebook for 64 symbols, 40 blocks (presumably one block per
    # message character, hence the 40-character requirement above).
    C = encoder.make_codebook(r=5, num_blocks=40)
    successes = 0
    print(f"Original message: {msg}")
    print(f"Running {num_trials} trials over '{mode}' channel...\n")
    for i in range(num_trials):
        # Transmit
        c = transmit(msg, C)
        # Channel
        Y = receive(c, mode, hostname, port)
        # Decode
        est = decoder.decode_blocks(Y, C)
        # Count character errors, including missing/extra characters
        errors = _count_char_errors(est, msg)
        if errors == 0:
            successes += 1
        print(f"Trial {i+1:3d}: Decoded: '{est}' | Errors: {errors}")
    pct = successes / num_trials * 100
    print("\n=== Summary ===")
    print(f"Total trials: {num_trials}")
    print(f"Perfect decodings: {successes}")
    print(f"Overall accuracy: {pct:.2f}%")
def parse_args():
    '''
    Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``message``, ``trials``, ``mode``, ``hostname``,
        ``port``, ``input_file`` and ``output_file`` attributes.
    '''
    cli = argparse.ArgumentParser(
        description='Test communication system performance.')
    # (flags, keyword arguments) for each option, in the order shown in --help.
    option_specs = (
        (('--message', '-m'),
         dict(type=str, required=True,
              help='40-character message to send.')),
        (('--trials', '-n'),
         dict(type=int, default=200,
              help='Number of trials.')),
        (('--mode',),
         dict(choices=['local', 'server'], default='local',
              help="Channel mode: 'local' or 'server'.")),
        (('--hostname',),
         dict(type=str, default='iscsrv72.epfl.ch',
              help='Server hostname for server mode.')),
        (('--port',),
         dict(type=int, default=80,
              help='Server port for server mode.')),
        (('--input_file', '-i'),
         dict(type=str, default=None,
              help='Path to write server input samples (input.txt).')),
        (('--output_file', '-o'),
         dict(type=str, default=None,
              help='Path to write server output samples (output.txt).')),
    )
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
if __name__ == '__main__':
    cli_args = parse_args()
    # Expose the optional file paths to receive_server via the module globals.
    INPUT_FILE, OUTPUT_FILE = cli_args.input_file, cli_args.output_file
    test_performance(
        cli_args.message,
        cli_args.trials,
        cli_args.mode,
        cli_args.hostname,
        cli_args.port,
    )