feat: get 67% correctness
This commit is contained in:
parent
df631a2199
commit
0c1737647b
4 changed files with 157 additions and 76 deletions
60
decoder.py
60
decoder.py
|
@ -1,40 +1,32 @@
|
|||
# decoder.py
|
||||
import numpy as np
|
||||
from numpy import logaddexp
|
||||
from utils import index_to_char
|
||||
from codebook import construct_codebook
|
||||
from encoder import ALPHABET, G
|
||||
|
||||
def decode_message(Y, codebook):
|
||||
|
||||
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
||||
"""
|
||||
ML decoding for unknown channel state s:
|
||||
p(Y|i) = 0.5*p(Y|i,s=1) + 0.5*p(Y|i,s=2)
|
||||
We use log-sum-exp to combine both branch metrics.
|
||||
Decode received samples by maximum correlation score
|
||||
"""
|
||||
G = 10
|
||||
Y1, Y2 = Y[::2], Y[1::2]
|
||||
best_idx = None
|
||||
best_metric = -np.inf
|
||||
for i, c in enumerate(codebook):
|
||||
# Only consider indices that map to characters
|
||||
if i not in index_to_char:
|
||||
continue
|
||||
c1, c2 = c[::2], c[1::2]
|
||||
# Branch metrics (up to additive constants)
|
||||
s1 = np.sqrt(G) * np.dot(Y1, c1) + np.dot(Y2, c2)
|
||||
s2 = np.dot(Y1, c1) + np.sqrt(G) * np.dot(Y2, c2)
|
||||
# Combine via log-sum-exp
|
||||
metric = logaddexp(s1, s2)
|
||||
if metric > best_metric:
|
||||
best_metric = metric
|
||||
best_idx = i
|
||||
return best_idx
|
||||
n = C.shape[1]
|
||||
half = n // 2
|
||||
num = Y.size // n
|
||||
C1 = C[:, :half]
|
||||
C2 = C[:, half:]
|
||||
sqrtG = np.sqrt(G)
|
||||
recovered = []
|
||||
for k in range(num):
|
||||
Yb = Y[k*n:(k+1)*n]
|
||||
Ye, Yo = Yb[0::2], Yb[1::2]
|
||||
s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
|
||||
s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
|
||||
score = np.maximum(s1, s2)
|
||||
best = int(np.argmax(score))
|
||||
recovered.append(ALPHABET[best])
|
||||
return "".join(recovered)
|
||||
|
||||
|
||||
def signal_to_text(Y, codebook, r=6):
|
||||
# Reconstruct codebook length (seg_len)
|
||||
_, seg_len, _, _ = construct_codebook(r, 1)
|
||||
text = ''
|
||||
for i in range(40):
|
||||
seg = Y[i * seg_len:(i + 1) * seg_len]
|
||||
idx = decode_message(seg, codebook)
|
||||
text += index_to_char.get(idx, '?')
|
||||
return text
|
||||
def count_errors(orig: str, est: str):
|
||||
"""
|
||||
List mismatches between orig and est
|
||||
"""
|
||||
return [(i, o, e) for i, (o, e) in enumerate(zip(orig, est)) if o != e]
|
66
encoder.py
66
encoder.py
|
@ -1,14 +1,56 @@
|
|||
# encoder.py (unchanged except default r)
|
||||
# encoder.py
|
||||
import numpy as np
|
||||
from codebook import construct_codebook
|
||||
from utils import char_to_index, normalize_energy
|
||||
|
||||
def text_to_signal(text, r=5, Eb=3):
|
||||
assert len(text) == 40, "Message must be exactly 40 characters."
|
||||
codebook, n, m, alpha = construct_codebook(r, Eb)
|
||||
# Map each character to its codeword
|
||||
msg_indices = [char_to_index[c] for c in text]
|
||||
signal = np.concatenate([codebook[i] for i in msg_indices])
|
||||
# Enforce the energy constraint
|
||||
signal = normalize_energy(signal, energy_limit=2000)
|
||||
return signal, codebook
|
||||
# System parameters
|
||||
G = 10.0 # power gain for even samples
|
||||
ENERGY_LIMIT = 2000.0 # total energy per block
|
||||
ALPHABET = (
|
||||
[chr(i) for i in range(ord('a'), ord('z')+1)] +
|
||||
[chr(i) for i in range(ord('A'), ord('Z')+1)] +
|
||||
[str(i) for i in range(10)] +
|
||||
[' ', '.']
|
||||
)
|
||||
assert len(ALPHABET) == 64, "Alphabet must be size 64"
|
||||
|
||||
|
||||
def hadamard(r: int) -> np.ndarray:
|
||||
if r == 0:
|
||||
return np.array([[1]], dtype=float)
|
||||
M = hadamard(r-1)
|
||||
return np.block([[M, M], [M, -M]])
|
||||
|
||||
|
||||
def Br(r: int) -> np.ndarray:
|
||||
M = hadamard(r)
|
||||
return np.vstack([M, -M])
|
||||
|
||||
|
||||
def make_codebook(r: int, num_blocks: int) -> np.ndarray:
|
||||
"""
|
||||
Build 64x64 codebook and scale blocks so energy per block ≤ ENERGY_LIMIT/num_blocks
|
||||
"""
|
||||
B = Br(r)
|
||||
C = np.hstack([B, B]).astype(float)
|
||||
raw_norm = np.sum(C[0]**2)
|
||||
margin = 0.95
|
||||
alpha = margin * (ENERGY_LIMIT / num_blocks) / raw_norm
|
||||
return np.sqrt(alpha) * C
|
||||
|
||||
|
||||
def interleave(c: np.ndarray) -> np.ndarray:
|
||||
half = c.size // 2
|
||||
x = np.empty(c.size)
|
||||
x[0::2] = c[:half]
|
||||
x[1::2] = c[half:]
|
||||
return x
|
||||
|
||||
|
||||
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Encode 40-character message into interleaved code symbols
|
||||
"""
|
||||
if len(msg) != 40:
|
||||
raise ValueError("Message must be exactly 40 characters.")
|
||||
idx = [ALPHABET.index(ch) for ch in msg]
|
||||
blocks = [interleave(C[i]) for i in idx]
|
||||
return np.concatenate(blocks)
|
|
@ -1,17 +1,36 @@
|
|||
from encoder import text_to_signal
|
||||
from decoder import signal_to_text
|
||||
# test_local.py
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import numpy as np
|
||||
from encoder import make_codebook, encode_message
|
||||
from decoder import decode_blocks, count_errors
|
||||
from channel import channel
|
||||
|
||||
def test_local():
|
||||
message = "HelloWorld123 ThisIsATestMessage12345678"
|
||||
x, codebook = text_to_signal(message, r=6, Eb=3)
|
||||
y = channel(x)
|
||||
decoded = signal_to_text(y, codebook, r=6)
|
||||
|
||||
print(f"Original: {message}")
|
||||
print(f"Decoded : {decoded}")
|
||||
errors = sum(1 for a, b in zip(message, decoded) if a != b)
|
||||
print(f"Character errors: {errors}/40")
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Local test using channel.py")
|
||||
parser.add_argument("--message", required=True, help="40-character message")
|
||||
args = parser.parse_args()
|
||||
|
||||
msg = args.message
|
||||
if len(msg) != 40:
|
||||
raise ValueError("Message must be exactly 40 characters.")
|
||||
C = make_codebook(r=5, num_blocks=len(msg))
|
||||
x = encode_message(msg, C)
|
||||
print(f"→ Total samples = {x.size}, total energy = {np.sum(x*x):.1f}")
|
||||
|
||||
Y = channel(x)
|
||||
|
||||
msg_hat = decode_blocks(Y, C)
|
||||
print(f"↓ Decoded message: {msg_hat}")
|
||||
|
||||
errors = count_errors(msg, msg_hat)
|
||||
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
|
||||
if errors:
|
||||
for i, o, e in errors:
|
||||
print(f" Pos {i}: sent '{o}' but got '{e}'")
|
||||
else:
|
||||
print("✔️ No decoding errors!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_local()
|
||||
main()
|
||||
|
|
|
@ -1,28 +1,56 @@
|
|||
# test_server.py
|
||||
#!/usr/bin/env python3
|
||||
import argparse
|
||||
import subprocess
|
||||
import numpy as np
|
||||
from encoder import text_to_signal
|
||||
from decoder import signal_to_text
|
||||
from encoder import make_codebook, encode_message
|
||||
from decoder import decode_blocks, count_errors
|
||||
|
||||
def test_server():
|
||||
message = "HelloWorld123 ThisIsATestMessage12345678"
|
||||
x, codebook = text_to_signal(message, r=6, Eb=3)
|
||||
np.savetxt("input.txt", x, fmt="%.10f")
|
||||
|
||||
def call_client(input_path, output_path, host, port):
|
||||
subprocess.run([
|
||||
"python3", "client.py",
|
||||
"--input_file", "input.txt",
|
||||
"--output_file", "output.txt",
|
||||
"--srv_hostname", "iscsrv72.epfl.ch",
|
||||
"--srv_port", "80"
|
||||
])
|
||||
f"--input_file={input_path}",
|
||||
f"--output_file={output_path}",
|
||||
f"--srv_hostname={host}",
|
||||
f"--srv_port={port}"
|
||||
], check=True)
|
||||
|
||||
y = np.loadtxt("output.txt")
|
||||
decoded = signal_to_text(y, codebook, r=6)
|
||||
|
||||
print(f"Original: {message}")
|
||||
print(f"Decoded : {decoded}")
|
||||
errors = sum(1 for a, b in zip(message, decoded) if a != b)
|
||||
print(f"Character errors: {errors}/40")
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Server test using client.py")
|
||||
parser.add_argument("--message", required=True, help="40-character message to send")
|
||||
parser.add_argument("--srv_hostname", default="iscsrv72.epfl.ch", help="Server hostname")
|
||||
parser.add_argument("--srv_port", type=int, default=80, help="Server port")
|
||||
args = parser.parse_args()
|
||||
|
||||
msg = args.message
|
||||
if len(msg) != 40:
|
||||
raise ValueError("Message must be exactly 40 characters.")
|
||||
C = make_codebook(r=5, num_blocks=len(msg))
|
||||
x = encode_message(msg, C)
|
||||
|
||||
# write encoded symbols to fixed input.txt
|
||||
input_file = "input.txt"
|
||||
output_file = "output.txt"
|
||||
np.savetxt(input_file, x)
|
||||
|
||||
# run client.py to read input.txt and write output.txt
|
||||
call_client(input_file, output_file, args.srv_hostname, args.srv_port)
|
||||
|
||||
# read received samples
|
||||
Y = np.loadtxt(output_file)
|
||||
|
||||
msg_hat = decode_blocks(Y, C)
|
||||
print(f"↓ Decoded message: {msg_hat}")
|
||||
|
||||
errors = count_errors(msg, msg_hat)
|
||||
print(f"Errors: {len(errors)} / {len(msg)} characters differ")
|
||||
if errors:
|
||||
for i, o, e in errors:
|
||||
print(f" Pos {i}: sent '{o}' but got '{e}'")
|
||||
else:
|
||||
print("✔️ No decoding errors!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_server()
|
||||
main()
|
Loading…
Add table
Reference in a new issue