From 0c1737647b246d022ced64257905c581c0f8636e Mon Sep 17 00:00:00 2001 From: appellet Date: Sun, 25 May 2025 16:35:41 +0200 Subject: [PATCH] feat: get 67% correctness --- decoder.py | 60 ++++++++++++++++++++------------------------- encoder.py | 66 +++++++++++++++++++++++++++++++++++++++++--------- test_local.py | 43 +++++++++++++++++++++++--------- test_server.py | 64 ++++++++++++++++++++++++++++++++++-------------- 4 files changed, 157 insertions(+), 76 deletions(-) diff --git a/decoder.py b/decoder.py index 3a82c04..75a0750 100644 --- a/decoder.py +++ b/decoder.py @@ -1,40 +1,32 @@ +# decoder.py import numpy as np -from numpy import logaddexp -from utils import index_to_char -from codebook import construct_codebook +from encoder import ALPHABET, G -def decode_message(Y, codebook): + +def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str: """ - ML decoding for unknown channel state s: - p(Y|i) = 0.5*p(Y|i,s=1) + 0.5*p(Y|i,s=2) - We use log-sum-exp to combine both branch metrics. + Decode received samples by maximum correlation score """ - G = 10 - Y1, Y2 = Y[::2], Y[1::2] - best_idx = None - best_metric = -np.inf - for i, c in enumerate(codebook): - # Only consider indices that map to characters - if i not in index_to_char: - continue - c1, c2 = c[::2], c[1::2] - # Branch metrics (up to additive constants) - s1 = np.sqrt(G) * np.dot(Y1, c1) + np.dot(Y2, c2) - s2 = np.dot(Y1, c1) + np.sqrt(G) * np.dot(Y2, c2) - # Combine via log-sum-exp - metric = logaddexp(s1, s2) - if metric > best_metric: - best_metric = metric - best_idx = i - return best_idx + n = C.shape[1] + half = n // 2 + num = Y.size // n + C1 = C[:, :half] + C2 = C[:, half:] + sqrtG = np.sqrt(G) + recovered = [] + for k in range(num): + Yb = Y[k*n:(k+1)*n] + Ye, Yo = Yb[0::2], Yb[1::2] + s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T) + s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T) + score = np.maximum(s1, s2) + best = int(np.argmax(score)) + recovered.append(ALPHABET[best]) + return "".join(recovered) -def signal_to_text(Y, codebook, r=6): - # Reconstruct codebook length (seg_len) - _, seg_len, _, _ = construct_codebook(r, 1) - text = '' - for i in range(40): - seg = Y[i * seg_len:(i + 1) * seg_len] - idx = decode_message(seg, codebook) - text += index_to_char.get(idx, '?') - return text \ No newline at end of file +def count_errors(orig: str, est: str): + """ + List mismatches between orig and est + """ + return [(i, o, e) for i, (o, e) in enumerate(zip(orig, est)) if o != e] \ No newline at end of file diff --git a/encoder.py b/encoder.py index 92e81a8..de75bba 100644 --- a/encoder.py +++ b/encoder.py @@ -1,14 +1,56 @@ -# encoder.py (unchanged except default r) +# encoder.py import numpy as np -from codebook import construct_codebook -from utils import char_to_index, normalize_energy -def text_to_signal(text, r=5, Eb=3): - assert len(text) == 40, "Message must be exactly 40 characters." - codebook, n, m, alpha = construct_codebook(r, Eb) - # Map each character to its codeword - msg_indices = [char_to_index[c] for c in text] - signal = np.concatenate([codebook[i] for i in msg_indices]) - # Enforce the energy constraint - signal = normalize_energy(signal, energy_limit=2000) - return signal, codebook \ No newline at end of file +# System parameters +G = 10.0 # power gain for even samples +ENERGY_LIMIT = 2000.0 # total energy per block +ALPHABET = ( + [chr(i) for i in range(ord('a'), ord('z')+1)] + + [chr(i) for i in range(ord('A'), ord('Z')+1)] + + [str(i) for i in range(10)] + + [' ', '.'] +) +assert len(ALPHABET) == 64, "Alphabet must be size 64" + + +def hadamard(r: int) -> np.ndarray: + if r == 0: + return np.array([[1]], dtype=float) + M = hadamard(r-1) + return np.block([[M, M], [M, -M]]) + + +def Br(r: int) -> np.ndarray: + M = hadamard(r) + return np.vstack([M, -M]) + + +def make_codebook(r: int, num_blocks: int) -> np.ndarray: + """ + Build 64x64 codebook and scale blocks so energy per block ≤ ENERGY_LIMIT/num_blocks + """ + B = Br(r) + C = np.hstack([B, B]).astype(float) + raw_norm = np.sum(C[0]**2) + margin = 0.95 + alpha = margin * (ENERGY_LIMIT / num_blocks) / raw_norm + return np.sqrt(alpha) * C + + +def interleave(c: np.ndarray) -> np.ndarray: + half = c.size // 2 + x = np.empty(c.size) + x[0::2] = c[:half] + x[1::2] = c[half:] + return x + + +def encode_message(msg: str, C: np.ndarray) -> np.ndarray: + """ + Encode 40-character message into interleaved code symbols + """ + if len(msg) != 40: + raise ValueError("Message must be exactly 40 characters.") + idx = [ALPHABET.index(ch) for ch in msg] + blocks = [interleave(C[i]) for i in idx] + return np.concatenate(blocks) \ No newline at end of file diff --git a/test_local.py b/test_local.py index 23a6901..4495240 100644 --- a/test_local.py +++ b/test_local.py @@ -1,17 +1,36 @@ -from encoder import text_to_signal -from decoder import signal_to_text +# test_local.py +#!/usr/bin/env python3 +import argparse +import numpy as np +from encoder import make_codebook, encode_message +from decoder import decode_blocks, count_errors from channel import channel -def test_local(): - message = "HelloWorld123 ThisIsATestMessage12345678" - x, codebook = text_to_signal(message, r=6, Eb=3) - y = channel(x) - decoded = signal_to_text(y, codebook, r=6) - print(f"Original: {message}") - print(f"Decoded : {decoded}") - errors = sum(1 for a, b in zip(message, decoded) if a != b) - print(f"Character errors: {errors}/40") +def main(): + parser = argparse.ArgumentParser(description="Local test using channel.py") + parser.add_argument("--message", required=True, help="40-character message") + args = parser.parse_args() + + msg = args.message + if len(msg) != 40: + raise ValueError("Message must be exactly 40 characters.") + C = make_codebook(r=5, num_blocks=len(msg)) + x = encode_message(msg, C) + print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}") + + Y = channel(x) + + msg_hat = decode_blocks(Y, C) + print(f"↓ Decoded message: {msg_hat}") + + errors = count_errors(msg, msg_hat) + print(f"Errors: {len(errors)} / {len(msg)} characters differ") + if errors: + for i, o, e in errors: + print(f" Pos {i}: sent '{o}' but got '{e}'") + else: + print("✔️ No decoding errors!") if __name__ == "__main__": - test_local() + main() diff --git a/test_server.py b/test_server.py index bb4190f..9201646 100644 --- a/test_server.py +++ b/test_server.py @@ -1,28 +1,56 @@ +# test_server.py +#!/usr/bin/env python3 +import argparse import subprocess import numpy as np -from encoder import text_to_signal -from decoder import signal_to_text +from encoder import make_codebook, encode_message +from decoder import decode_blocks, count_errors -def test_server(): - message = "HelloWorld123 ThisIsATestMessage12345678" - x, codebook = text_to_signal(message, r=6, Eb=3) - np.savetxt("input.txt", x, fmt="%.10f") +def call_client(input_path, output_path, host, port): subprocess.run([ "python3", "client.py", - "--input_file", "input.txt", - "--output_file", "output.txt", - "--srv_hostname", "iscsrv72.epfl.ch", - "--srv_port", "80" - ]) + f"--input_file={input_path}", + f"--output_file={output_path}", + f"--srv_hostname={host}", + f"--srv_port={port}" + ], check=True) - y = np.loadtxt("output.txt") - decoded = signal_to_text(y, codebook, r=6) - print(f"Original: {message}") - print(f"Decoded : {decoded}") - errors = sum(1 for a, b in zip(message, decoded) if a != b) - print(f"Character errors: {errors}/40") +def main(): + parser = argparse.ArgumentParser(description="Server test using client.py") + parser.add_argument("--message", required=True, help="40-character message to send") + parser.add_argument("--srv_hostname", default="iscsrv72.epfl.ch", help="Server hostname") + parser.add_argument("--srv_port", type=int, default=80, help="Server port") + args = parser.parse_args() + + msg = args.message + if len(msg) != 40: + raise ValueError("Message must be exactly 40 characters.") + C = make_codebook(r=5, num_blocks=len(msg)) + x = encode_message(msg, C) + + # write encoded symbols to fixed input.txt + input_file = "input.txt" + output_file = "output.txt" + np.savetxt(input_file, x) + + # run client.py to read input.txt and write output.txt + call_client(input_file, output_file, args.srv_hostname, args.srv_port) + + # read received samples + Y = np.loadtxt(output_file) + + msg_hat = decode_blocks(Y, C) + print(f"↓ Decoded message: {msg_hat}") + + errors = count_errors(msg, msg_hat) + print(f"Errors: {len(errors)} / {len(msg)} characters differ") + if errors: + for i, o, e in errors: + print(f" Pos {i}: sent '{o}' but got '{e}'") + else: + print("✔️ No decoding errors!") if __name__ == "__main__": - test_server() \ No newline at end of file + main() \ No newline at end of file