feat: get 67% correctness

This commit is contained in:
appellet 2025-05-25 16:35:41 +02:00
parent df631a2199
commit 0c1737647b
4 changed files with 157 additions and 76 deletions

View file

@ -1,40 +1,32 @@
# decoder.py
import numpy as np import numpy as np
from numpy import logaddexp from encoder import ALPHABET, G
from utils import index_to_char
from codebook import construct_codebook
def decode_message(Y, codebook):
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
""" """
ML decoding for unknown channel state s: Decode received samples by maximum correlation score
p(Y|i) = 0.5*p(Y|i,s=1) + 0.5*p(Y|i,s=2)
We use log-sum-exp to combine both branch metrics.
""" """
G = 10 n = C.shape[1]
Y1, Y2 = Y[::2], Y[1::2] half = n // 2
best_idx = None num = Y.size // n
best_metric = -np.inf C1 = C[:, :half]
for i, c in enumerate(codebook): C2 = C[:, half:]
# Only consider indices that map to characters sqrtG = np.sqrt(G)
if i not in index_to_char: recovered = []
continue for k in range(num):
c1, c2 = c[::2], c[1::2] Yb = Y[k*n:(k+1)*n]
# Branch metrics (up to additive constants) Ye, Yo = Yb[0::2], Yb[1::2]
s1 = np.sqrt(G) * np.dot(Y1, c1) + np.dot(Y2, c2) s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
s2 = np.dot(Y1, c1) + np.sqrt(G) * np.dot(Y2, c2) s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
# Combine via log-sum-exp score = np.maximum(s1, s2)
metric = logaddexp(s1, s2) best = int(np.argmax(score))
if metric > best_metric: recovered.append(ALPHABET[best])
best_metric = metric return "".join(recovered)
best_idx = i
return best_idx
def signal_to_text(Y, codebook, r=6): def count_errors(orig: str, est: str):
# Reconstruct codebook length (seg_len) """
_, seg_len, _, _ = construct_codebook(r, 1) List mismatches between orig and est
text = '' """
for i in range(40): return [(i, o, e) for i, (o, e) in enumerate(zip(orig, est)) if o != e]
seg = Y[i * seg_len:(i + 1) * seg_len]
idx = decode_message(seg, codebook)
text += index_to_char.get(idx, '?')
return text

View file

@ -1,14 +1,56 @@
# encoder.py (unchanged except default r) # encoder.py
import numpy as np import numpy as np
from codebook import construct_codebook
from utils import char_to_index, normalize_energy
def text_to_signal(text, r=5, Eb=3): # System parameters
assert len(text) == 40, "Message must be exactly 40 characters." G = 10.0 # power gain for even samples
codebook, n, m, alpha = construct_codebook(r, Eb) ENERGY_LIMIT = 2000.0 # total energy per block
# Map each character to its codeword ALPHABET = (
msg_indices = [char_to_index[c] for c in text] [chr(i) for i in range(ord('a'), ord('z')+1)] +
signal = np.concatenate([codebook[i] for i in msg_indices]) [chr(i) for i in range(ord('A'), ord('Z')+1)] +
# Enforce the energy constraint [str(i) for i in range(10)] +
signal = normalize_energy(signal, energy_limit=2000) [' ', '.']
return signal, codebook )
assert len(ALPHABET) == 64, "Alphabet must be size 64"
def hadamard(r: int) -> np.ndarray:
if r == 0:
return np.array([[1]], dtype=float)
M = hadamard(r-1)
return np.block([[M, M], [M, -M]])
def Br(r: int) -> np.ndarray:
M = hadamard(r)
return np.vstack([M, -M])
def make_codebook(r: int, num_blocks: int) -> np.ndarray:
"""
Build 64x64 codebook and scale blocks so energy per block ENERGY_LIMIT/num_blocks
"""
B = Br(r)
C = np.hstack([B, B]).astype(float)
raw_norm = np.sum(C[0]**2)
margin = 0.95
alpha = margin * (ENERGY_LIMIT / num_blocks) / raw_norm
return np.sqrt(alpha) * C
def interleave(c: np.ndarray) -> np.ndarray:
half = c.size // 2
x = np.empty(c.size)
x[0::2] = c[:half]
x[1::2] = c[half:]
return x
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
"""
Encode 40-character message into interleaved code symbols
"""
if len(msg) != 40:
raise ValueError("Message must be exactly 40 characters.")
idx = [ALPHABET.index(ch) for ch in msg]
blocks = [interleave(C[i]) for i in idx]
return np.concatenate(blocks)

View file

@ -1,17 +1,36 @@
from encoder import text_to_signal # test_local.py
from decoder import signal_to_text #!/usr/bin/env python3
import argparse
import numpy as np
from encoder import make_codebook, encode_message
from decoder import decode_blocks, count_errors
from channel import channel from channel import channel
def test_local():
message = "HelloWorld123 ThisIsATestMessage12345678"
x, codebook = text_to_signal(message, r=6, Eb=3)
y = channel(x)
decoded = signal_to_text(y, codebook, r=6)
print(f"Original: {message}") def main():
print(f"Decoded : {decoded}") parser = argparse.ArgumentParser(description="Local test using channel.py")
errors = sum(1 for a, b in zip(message, decoded) if a != b) parser.add_argument("--message", required=True, help="40-character message")
print(f"Character errors: {errors}/40") args = parser.parse_args()
msg = args.message
if len(msg) != 40:
raise ValueError("Message must be exactly 40 characters.")
C = make_codebook(r=5, num_blocks=len(msg))
x = encode_message(msg, C)
print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}")
Y = channel(x)
msg_hat = decode_blocks(Y, C)
print(f"↓ Decoded message: {msg_hat}")
errors = count_errors(msg, msg_hat)
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
if errors:
for i, o, e in errors:
print(f" Pos {i}: sent '{o}' but got '{e}'")
else:
print("✔️ No decoding errors!")
if __name__ == "__main__": if __name__ == "__main__":
test_local() main()

View file

@ -1,28 +1,56 @@
# test_server.py
#!/usr/bin/env python3
import argparse
import subprocess import subprocess
import numpy as np import numpy as np
from encoder import text_to_signal from encoder import make_codebook, encode_message
from decoder import signal_to_text from decoder import decode_blocks, count_errors
def test_server():
message = "HelloWorld123 ThisIsATestMessage12345678"
x, codebook = text_to_signal(message, r=6, Eb=3)
np.savetxt("input.txt", x, fmt="%.10f")
def call_client(input_path, output_path, host, port):
subprocess.run([ subprocess.run([
"python3", "client.py", "python3", "client.py",
"--input_file", "input.txt", f"--input_file={input_path}",
"--output_file", "output.txt", f"--output_file={output_path}",
"--srv_hostname", "iscsrv72.epfl.ch", f"--srv_hostname={host}",
"--srv_port", "80" f"--srv_port={port}"
]) ], check=True)
y = np.loadtxt("output.txt")
decoded = signal_to_text(y, codebook, r=6)
print(f"Original: {message}") def main():
print(f"Decoded : {decoded}") parser = argparse.ArgumentParser(description="Server test using client.py")
errors = sum(1 for a, b in zip(message, decoded) if a != b) parser.add_argument("--message", required=True, help="40-character message to send")
print(f"Character errors: {errors}/40") parser.add_argument("--srv_hostname", default="iscsrv72.epfl.ch", help="Server hostname")
parser.add_argument("--srv_port", type=int, default=80, help="Server port")
args = parser.parse_args()
msg = args.message
if len(msg) != 40:
raise ValueError("Message must be exactly 40 characters.")
C = make_codebook(r=5, num_blocks=len(msg))
x = encode_message(msg, C)
# write encoded symbols to fixed input.txt
input_file = "input.txt"
output_file = "output.txt"
np.savetxt(input_file, x)
# run client.py to read input.txt and write output.txt
call_client(input_file, output_file, args.srv_hostname, args.srv_port)
# read received samples
Y = np.loadtxt(output_file)
msg_hat = decode_blocks(Y, C)
print(f"↓ Decoded message: {msg_hat}")
errors = count_errors(msg, msg_hat)
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
if errors:
for i, o, e in errors:
print(f" Pos {i}: sent '{o}' but got '{e}'")
else:
print("✔️ No decoding errors!")
if __name__ == "__main__": if __name__ == "__main__":
test_server() main()