fix: main
This commit is contained in:
parent
a27e7b4adb
commit
094d90715e
4 changed files with 100 additions and 173 deletions
144
main.py
144
main.py
|
@ -1,58 +1,114 @@
|
||||||
# main.py
|
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import socket
|
import sys
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import channel_helper as ch
|
import encoder
|
||||||
from encoder import make_codebook, encode_message
|
import decoder
|
||||||
from decoder import decode_blocks, count_errors
|
import channel
|
||||||
from channel import channel as external_channel
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import pathlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
def send_and_recv(x: np.ndarray, host: str, port: int) -> np.ndarray:
|
def transmit(msg, C):
|
||||||
"""Send samples x to server and receive output via TCP"""
|
"""
|
||||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
Transmitter: encodes the message into real-valued samples using the codebook C.
|
||||||
sock.connect((host, port))
|
"""
|
||||||
header = b'0' + b'dUV'
|
return encoder.encode_message(msg, C)
|
||||||
ch.send_msg(sock, header, x)
|
|
||||||
_, Y = ch.recv_msg(sock)
|
|
||||||
|
def receive_local(c):
|
||||||
|
"""
|
||||||
|
Sends the samples through the local channel simulation.
|
||||||
|
"""
|
||||||
|
return channel.channel(c)
|
||||||
|
|
||||||
|
|
||||||
|
def receive_server(c, hostname, port):
|
||||||
|
"""
|
||||||
|
Sends the samples to the remote server via client.py and retrieves the output.
|
||||||
|
"""
|
||||||
|
# Write input samples to a temporary file
|
||||||
|
with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as in_f:
|
||||||
|
np.savetxt(in_f.name, c)
|
||||||
|
in_name = in_f.name
|
||||||
|
# Prepare output file
|
||||||
|
out_fd, out_name = tempfile.mkstemp(suffix='.txt')
|
||||||
|
os.close(out_fd)
|
||||||
|
# Invoke client.py
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
str(pathlib.Path(__file__).parent / 'client.py'),
|
||||||
|
'--input_file', in_name,
|
||||||
|
'--output_file', out_name,
|
||||||
|
'--srv_hostname', hostname,
|
||||||
|
'--srv_port', str(port)
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
subprocess.run(cmd, check=True)
|
||||||
|
Y = np.loadtxt(out_name)
|
||||||
|
finally:
|
||||||
|
# Clean up temp files
|
||||||
|
os.remove(in_name)
|
||||||
|
os.remove(out_name)
|
||||||
return Y
|
return Y
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def receive(c, mode, hostname, port):
|
||||||
p = argparse.ArgumentParser(description="PDC Tx/Rx local or server")
|
"""
|
||||||
p.add_argument("--message", required=True, help="40-character message to send")
|
Wrapper to choose local or server channel.
|
||||||
p.add_argument("--srv_hostname", help="Server hostname")
|
"""
|
||||||
p.add_argument("--srv_port", type=int, help="Server port")
|
if mode == 'local':
|
||||||
p.add_argument("--local", action='store_true', help="Use local channel simulation")
|
return receive_local(c)
|
||||||
args = p.parse_args()
|
elif mode == 'server':
|
||||||
|
return receive_server(c, hostname, port)
|
||||||
|
else:
|
||||||
|
raise ValueError("Mode must be 'local' or 'server'")
|
||||||
|
|
||||||
msg = args.message
|
|
||||||
|
def test_performance(msg, num_trials, mode, hostname, port):
|
||||||
|
"""
|
||||||
|
Runs num_trials transmissions of msg through the specified channel and reports accuracy.
|
||||||
|
"""
|
||||||
if len(msg) != 40:
|
if len(msg) != 40:
|
||||||
raise ValueError("Message must be exactly 40 characters.")
|
raise ValueError("Message must be exactly 40 characters.")
|
||||||
num_blocks = len(msg)
|
# Build codebook for 64 symbols, 40 blocks
|
||||||
C = make_codebook(r=6, num_blocks=num_blocks)
|
C = encoder.make_codebook(r=5, num_blocks=40)
|
||||||
x = encode_message(msg, C)
|
successes = 0
|
||||||
print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}")
|
for i in range(num_trials):
|
||||||
|
# Transmit
|
||||||
|
c = transmit(msg, C)
|
||||||
|
# Channel
|
||||||
|
Y = receive(c, mode, hostname, port)
|
||||||
|
# Decode
|
||||||
|
est = decoder.decode_blocks(Y, C)
|
||||||
|
if est == msg:
|
||||||
|
successes += 1
|
||||||
|
pct = successes / num_trials * 100
|
||||||
|
# Display results
|
||||||
|
print(f"Message: {msg}")
|
||||||
|
print(f"Trials: {num_trials}")
|
||||||
|
print(f"Mode: {mode}")
|
||||||
|
print(f"Correct decodings: {successes}")
|
||||||
|
print(f"Accuracy: {pct:.2f}%")
|
||||||
|
|
||||||
if args.local:
|
|
||||||
print("-- Local simulation mode --")
|
|
||||||
Y = external_channel(x)
|
|
||||||
else:
|
|
||||||
if not args.srv_hostname or not args.srv_port:
|
|
||||||
raise ValueError("Must specify --srv_hostname and --srv_port unless --local")
|
|
||||||
Y = send_and_recv(x, args.srv_hostname, args.srv_port)
|
|
||||||
|
|
||||||
msg_hat = decode_blocks(Y, C)
|
def parse_args():
|
||||||
print(f"↓ Decoded message: {msg_hat}")
|
parser = argparse.ArgumentParser(description="Test communication system performance.")
|
||||||
|
parser.add_argument('--message', '-m', type=str, required=True,
|
||||||
|
help="40-character message to send.")
|
||||||
|
parser.add_argument('--trials', '-n', type=int, default=1,
|
||||||
|
help="Number of trials.")
|
||||||
|
parser.add_argument('--mode', choices=['local','server'], default='local',
|
||||||
|
help="Channel mode: 'local' or 'server'.")
|
||||||
|
parser.add_argument('--hostname', type=str, default='iscsrv72.epfl.ch',
|
||||||
|
help="Server hostname for server mode.")
|
||||||
|
parser.add_argument('--port', type=int, default=80,
|
||||||
|
help="Server port for server mode.")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
errors = count_errors(msg, msg_hat)
|
if __name__ == '__main__':
|
||||||
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
|
args = parse_args()
|
||||||
if errors:
|
test_performance(args.message, args.trials, args.mode, args.hostname, args.port)
|
||||||
for i, o, e in errors:
|
|
||||||
print(f" Pos {i}: sent '{o}' but got '{e}'")
|
|
||||||
else:
|
|
||||||
print("✔️ No decoding errors!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
|
|
|
@ -1,37 +0,0 @@
|
||||||
# performance_local.py
|
|
||||||
#!/usr/bin/env python3
|
|
||||||
import argparse
|
|
||||||
import numpy as np
|
|
||||||
import random
|
|
||||||
from encoder import make_codebook, encode_message, ALPHABET
|
|
||||||
from decoder import decode_blocks, count_errors
|
|
||||||
from channel import channel
|
|
||||||
|
|
||||||
|
|
||||||
def random_message(length):
|
|
||||||
return ''.join(random.choice(ALPHABET) for _ in range(length))
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Monte Carlo evaluation over local channel")
|
|
||||||
parser.add_argument("--num", type=int, required=True, help="Number of trials")
|
|
||||||
parser.add_argument("--r", type=int, default=5, help="Hadamard order (default 5)")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
num_trials = args.num
|
|
||||||
successes = 0
|
|
||||||
|
|
||||||
for _ in range(num_trials):
|
|
||||||
msg = random_message(40)
|
|
||||||
C = make_codebook(r=args.r, num_blocks=len(msg))
|
|
||||||
x = encode_message(msg, C)
|
|
||||||
Y = channel(x)
|
|
||||||
msg_hat = decode_blocks(Y, C)
|
|
||||||
if msg_hat == msg:
|
|
||||||
successes += 1
|
|
||||||
|
|
||||||
ratio = successes / num_trials
|
|
||||||
print(f"Correctly decoded messages: {successes}/{num_trials} ({ratio:.2%})")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,36 +0,0 @@
|
||||||
# test_local.py
|
|
||||||
#!/usr/bin/env python3
|
|
||||||
import argparse
|
|
||||||
import numpy as np
|
|
||||||
from encoder import make_codebook, encode_message
|
|
||||||
from decoder import decode_blocks, count_errors
|
|
||||||
from channel import channel
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Local test using channel.py")
|
|
||||||
parser.add_argument("--message", required=True, help="40-character message")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
msg = args.message
|
|
||||||
if len(msg) != 40:
|
|
||||||
raise ValueError("Message must be exactly 40 characters.")
|
|
||||||
C = make_codebook(r=5, num_blocks=len(msg))
|
|
||||||
x = encode_message(msg, C)
|
|
||||||
print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}")
|
|
||||||
|
|
||||||
Y = channel(x)
|
|
||||||
|
|
||||||
msg_hat = decode_blocks(Y, C)
|
|
||||||
print(f"↓ Decoded message: {msg_hat}")
|
|
||||||
|
|
||||||
errors = count_errors(msg, msg_hat)
|
|
||||||
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
|
|
||||||
if errors:
|
|
||||||
for i, o, e in errors:
|
|
||||||
print(f" Pos {i}: sent '{o}' but got '{e}'")
|
|
||||||
else:
|
|
||||||
print("✔️ No decoding errors!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
|
@ -1,56 +0,0 @@
|
||||||
# test_server.py
|
|
||||||
#!/usr/bin/env python3
|
|
||||||
import argparse
|
|
||||||
import subprocess
|
|
||||||
import numpy as np
|
|
||||||
from encoder import make_codebook, encode_message
|
|
||||||
from decoder import decode_blocks, count_errors
|
|
||||||
|
|
||||||
|
|
||||||
def call_client(input_path, output_path, host, port):
|
|
||||||
subprocess.run([
|
|
||||||
"python3", "client.py",
|
|
||||||
f"--input_file={input_path}",
|
|
||||||
f"--output_file={output_path}",
|
|
||||||
f"--srv_hostname={host}",
|
|
||||||
f"--srv_port={port}"
|
|
||||||
], check=True)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Server test using client.py")
|
|
||||||
parser.add_argument("--message", required=True, help="40-character message to send")
|
|
||||||
parser.add_argument("--srv_hostname", default="iscsrv72.epfl.ch", help="Server hostname")
|
|
||||||
parser.add_argument("--srv_port", type=int, default=80, help="Server port")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
msg = args.message
|
|
||||||
if len(msg) != 40:
|
|
||||||
raise ValueError("Message must be exactly 40 characters.")
|
|
||||||
C = make_codebook(r=5, num_blocks=len(msg))
|
|
||||||
x = encode_message(msg, C)
|
|
||||||
|
|
||||||
# write encoded symbols to fixed input.txt
|
|
||||||
input_file = "input.txt"
|
|
||||||
output_file = "output.txt"
|
|
||||||
np.savetxt(input_file, x)
|
|
||||||
|
|
||||||
# run client.py to read input.txt and write output.txt
|
|
||||||
call_client(input_file, output_file, args.srv_hostname, args.srv_port)
|
|
||||||
|
|
||||||
# read received samples
|
|
||||||
Y = np.loadtxt(output_file)
|
|
||||||
|
|
||||||
msg_hat = decode_blocks(Y, C)
|
|
||||||
print(f"↓ Decoded message: {msg_hat}")
|
|
||||||
|
|
||||||
errors = count_errors(msg, msg_hat)
|
|
||||||
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
|
|
||||||
if errors:
|
|
||||||
for i, o, e in errors:
|
|
||||||
print(f" Pos {i}: sent '{o}' but got '{e}'")
|
|
||||||
else:
|
|
||||||
print("✔️ No decoding errors!")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
Loading…
Add table
Reference in a new issue