fix: main

This commit is contained in:
appellet 2025-05-27 11:23:30 +02:00
parent a27e7b4adb
commit 094d90715e
4 changed files with 100 additions and 173 deletions

144
main.py
View file

@ -1,58 +1,114 @@
# main.py
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse import argparse
import socket import sys
import numpy as np import numpy as np
import channel_helper as ch import encoder
from encoder import make_codebook, encode_message import decoder
from decoder import decode_blocks, count_errors import channel
from channel import channel as external_channel import subprocess
import tempfile
import pathlib
import os
def send_and_recv(x: np.ndarray, host: str, port: int) -> np.ndarray: def transmit(msg, C):
"""Send samples x to server and receive output via TCP""" """
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: Transmitter: encodes the message into real-valued samples using the codebook C.
sock.connect((host, port)) """
header = b'0' + b'dUV' return encoder.encode_message(msg, C)
ch.send_msg(sock, header, x)
_, Y = ch.recv_msg(sock)
def receive_local(c):
"""
Sends the samples through the local channel simulation.
"""
return channel.channel(c)
def receive_server(c, hostname, port):
"""
Sends the samples to the remote server via client.py and retrieves the output.
"""
# Write input samples to a temporary file
with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as in_f:
np.savetxt(in_f.name, c)
in_name = in_f.name
# Prepare output file
out_fd, out_name = tempfile.mkstemp(suffix='.txt')
os.close(out_fd)
# Invoke client.py
cmd = [
sys.executable,
str(pathlib.Path(__file__).parent / 'client.py'),
'--input_file', in_name,
'--output_file', out_name,
'--srv_hostname', hostname,
'--srv_port', str(port)
]
try:
subprocess.run(cmd, check=True)
Y = np.loadtxt(out_name)
finally:
# Clean up temp files
os.remove(in_name)
os.remove(out_name)
return Y return Y
def main(): def receive(c, mode, hostname, port):
p = argparse.ArgumentParser(description="PDC Tx/Rx local or server") """
p.add_argument("--message", required=True, help="40-character message to send") Wrapper to choose local or server channel.
p.add_argument("--srv_hostname", help="Server hostname") """
p.add_argument("--srv_port", type=int, help="Server port") if mode == 'local':
p.add_argument("--local", action='store_true', help="Use local channel simulation") return receive_local(c)
args = p.parse_args() elif mode == 'server':
return receive_server(c, hostname, port)
else:
raise ValueError("Mode must be 'local' or 'server'")
msg = args.message
def test_performance(msg, num_trials, mode, hostname, port):
"""
Runs num_trials transmissions of msg through the specified channel and reports accuracy.
"""
if len(msg) != 40: if len(msg) != 40:
raise ValueError("Message must be exactly 40 characters.") raise ValueError("Message must be exactly 40 characters.")
num_blocks = len(msg) # Build codebook for 64 symbols, 40 blocks
C = make_codebook(r=6, num_blocks=num_blocks) C = encoder.make_codebook(r=5, num_blocks=40)
x = encode_message(msg, C) successes = 0
print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}") for i in range(num_trials):
# Transmit
c = transmit(msg, C)
# Channel
Y = receive(c, mode, hostname, port)
# Decode
est = decoder.decode_blocks(Y, C)
if est == msg:
successes += 1
pct = successes / num_trials * 100
# Display results
print(f"Message: {msg}")
print(f"Trials: {num_trials}")
print(f"Mode: {mode}")
print(f"Correct decodings: {successes}")
print(f"Accuracy: {pct:.2f}%")
if args.local:
print("-- Local simulation mode --")
Y = external_channel(x)
else:
if not args.srv_hostname or not args.srv_port:
raise ValueError("Must specify --srv_hostname and --srv_port unless --local")
Y = send_and_recv(x, args.srv_hostname, args.srv_port)
msg_hat = decode_blocks(Y, C) def parse_args():
print(f"↓ Decoded message: {msg_hat}") parser = argparse.ArgumentParser(description="Test communication system performance.")
parser.add_argument('--message', '-m', type=str, required=True,
help="40-character message to send.")
parser.add_argument('--trials', '-n', type=int, default=1,
help="Number of trials.")
parser.add_argument('--mode', choices=['local','server'], default='local',
help="Channel mode: 'local' or 'server'.")
parser.add_argument('--hostname', type=str, default='iscsrv72.epfl.ch',
help="Server hostname for server mode.")
parser.add_argument('--port', type=int, default=80,
help="Server port for server mode.")
return parser.parse_args()
errors = count_errors(msg, msg_hat) if __name__ == '__main__':
print(f"Errors: {len(errors)} / {len(msg)} characters differ") args = parse_args()
if errors: test_performance(args.message, args.trials, args.mode, args.hostname, args.port)
for i, o, e in errors:
print(f" Pos {i}: sent '{o}' but got '{e}'")
else:
print("✔️ No decoding errors!")
if __name__ == "__main__":
main()

View file

@ -1,37 +0,0 @@
# performance_local.py
#!/usr/bin/env python3
import argparse
import numpy as np
import random
from encoder import make_codebook, encode_message, ALPHABET
from decoder import decode_blocks, count_errors
from channel import channel
def random_message(length):
return ''.join(random.choice(ALPHABET) for _ in range(length))
def main():
parser = argparse.ArgumentParser(description="Monte Carlo evaluation over local channel")
parser.add_argument("--num", type=int, required=True, help="Number of trials")
parser.add_argument("--r", type=int, default=5, help="Hadamard order (default 5)")
args = parser.parse_args()
num_trials = args.num
successes = 0
for _ in range(num_trials):
msg = random_message(40)
C = make_codebook(r=args.r, num_blocks=len(msg))
x = encode_message(msg, C)
Y = channel(x)
msg_hat = decode_blocks(Y, C)
if msg_hat == msg:
successes += 1
ratio = successes / num_trials
print(f"Correctly decoded messages: {successes}/{num_trials} ({ratio:.2%})")
if __name__ == "__main__":
main()

View file

@ -1,36 +0,0 @@
# test_local.py
#!/usr/bin/env python3
import argparse
import numpy as np
from encoder import make_codebook, encode_message
from decoder import decode_blocks, count_errors
from channel import channel
def main():
parser = argparse.ArgumentParser(description="Local test using channel.py")
parser.add_argument("--message", required=True, help="40-character message")
args = parser.parse_args()
msg = args.message
if len(msg) != 40:
raise ValueError("Message must be exactly 40 characters.")
C = make_codebook(r=5, num_blocks=len(msg))
x = encode_message(msg, C)
print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}")
Y = channel(x)
msg_hat = decode_blocks(Y, C)
print(f"↓ Decoded message: {msg_hat}")
errors = count_errors(msg, msg_hat)
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
if errors:
for i, o, e in errors:
print(f" Pos {i}: sent '{o}' but got '{e}'")
else:
print("✔️ No decoding errors!")
if __name__ == "__main__":
main()

View file

@ -1,56 +0,0 @@
# test_server.py
#!/usr/bin/env python3
import argparse
import subprocess
import numpy as np
from encoder import make_codebook, encode_message
from decoder import decode_blocks, count_errors
def call_client(input_path, output_path, host, port):
subprocess.run([
"python3", "client.py",
f"--input_file={input_path}",
f"--output_file={output_path}",
f"--srv_hostname={host}",
f"--srv_port={port}"
], check=True)
def main():
parser = argparse.ArgumentParser(description="Server test using client.py")
parser.add_argument("--message", required=True, help="40-character message to send")
parser.add_argument("--srv_hostname", default="iscsrv72.epfl.ch", help="Server hostname")
parser.add_argument("--srv_port", type=int, default=80, help="Server port")
args = parser.parse_args()
msg = args.message
if len(msg) != 40:
raise ValueError("Message must be exactly 40 characters.")
C = make_codebook(r=5, num_blocks=len(msg))
x = encode_message(msg, C)
# write encoded symbols to fixed input.txt
input_file = "input.txt"
output_file = "output.txt"
np.savetxt(input_file, x)
# run client.py to read input.txt and write output.txt
call_client(input_file, output_file, args.srv_hostname, args.srv_port)
# read received samples
Y = np.loadtxt(output_file)
msg_hat = decode_blocks(Y, C)
print(f"↓ Decoded message: {msg_hat}")
errors = count_errors(msg, msg_hat)
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
if errors:
for i, o, e in errors:
print(f" Pos {i}: sent '{o}' but got '{e}'")
else:
print("✔️ No decoding errors!")
if __name__ == "__main__":
main()