feat: improve code

This commit is contained in:
appellet 2025-05-30 02:00:30 +02:00
parent 18aca25240
commit fdb07e1c00
8 changed files with 159 additions and 417 deletions

View file

@ -66,6 +66,24 @@ python3 main_backup.py \
> After running, `input.txt` contains your transmitted samples, and `output.txt` contains the noisy output from the server.
### 4. Create input.txt
```bash
python3 encoder.py "message_40_characters"
```
### 5. Create output.txt through the channel
```bash
python3 channel.py
```
### 6. Decode output.txt
```bash
python3 decoder.py
```
---
## Manual decoding of server output

View file

@ -18,3 +18,12 @@ def channel(x):
Y[::2] += x_even
Y[1::2] += x_odd
return Y
if __name__ == "__main__":
    # Script entry point: load the samples stored in input.txt, send them
    # through channel(), and save the result to output.txt.
    transmitted = np.loadtxt("input.txt")
    received = channel(transmitted)
    np.savetxt("output.txt", received)
    print("Channel output written to output.txt")

View file

@ -1,43 +1,49 @@
import numpy as np
from encoder import (
G, CHAR_SET, pair_to_index, index_to_pair, make_codebook
)
print(np.__config__.show())
from encoder import G, ALPHABET, make_codebook
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray):
    """Decode the received samples and estimate which channel state was active.

    Every codeword occupies ``2 * C.shape[1]`` received samples because the
    encoder transmits each codeword sample twice. For each block the even- and
    odd-indexed samples are correlated against every codeword and combined
    under two hypotheses: state 1 weights the even samples by ``sqrt(G)``,
    state 2 weights the odd samples by ``sqrt(G)``. The state whose best
    per-block scores add up to the larger total wins (ties go to state 1).

    Args:
        Y: received samples; ``Y.size`` must be a multiple of
            ``2 * C.shape[1]``.
        C: codebook with one codeword per row.

    Returns:
        ``(decoded, state)``: the decoded text (two characters per block)
        and the estimated state, ``1`` or ``2``.

    Raises:
        AssertionError: if the length checks on ``Y`` fail.
    """
    N = Y.size
    assert N % 2 == 0
    # NOTE(review): presumably the channel's maximum input length — confirm.
    assert N <= 1_000_000

    n_cw = C.shape[1] * 2  # samples per codeword after duplication
    nb = N // n_cw
    assert nb * n_cw == N, "length mismatch between Y and codebook"

    # One row per block, then split into the even / odd sample streams.
    Y_blocks = Y.reshape(nb, n_cw)
    Y_even = Y_blocks[:, ::2].astype(np.float32, copy=False)
    Y_odd = Y_blocks[:, 1::2].astype(np.float32, copy=False)
    C = C.astype(np.float32, copy=False)
    sqrtg = np.sqrt(G)

    # Correlate every block with every codeword once; both hypotheses reuse
    # the same two products with different weights.
    corr_even = Y_even @ C.T
    corr_odd = Y_odd @ C.T
    s1 = sqrtg * corr_even + corr_odd
    s2 = corr_even + sqrtg * corr_odd

    best_if_s1 = np.argmax(s1, axis=1)
    best_if_s2 = np.argmax(s2, axis=1)
    tot1 = np.sum(np.max(s1, axis=1))
    tot2 = np.sum(np.max(s2, axis=1))

    state = 1 if tot1 >= tot2 else 2
    indices = best_if_s1 if state == 1 else best_if_s2

    # A codeword index packs two alphabet positions: high 6 bits, low 6 bits.
    decoded = ''.join(
        ALPHABET[idx >> 6] + ALPHABET[idx & 0x3F] for idx in indices
    )
    return decoded, state
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
    """Return only the decoded text, dropping the estimated channel state."""
    result = decode_blocks_with_state(Y, C)
    return result[0]
if __name__ == "__main__":
    # Script entry point: decode the samples stored in output.txt with the
    # default codebook and report the message and the detected state.
    received = np.loadtxt("output.txt")
    codebook, _alpha = make_codebook()
    message, channel_state = decode_blocks_with_state(received, codebook)
    print(f"Decoded message: {message}")
    print(f"Detected channel state: {channel_state}")

View file

@ -1,103 +0,0 @@
# decoder_backup.py
import numpy as np
from encoder_backup import ALPHABET, G
# Noise variance assumed by the decoder; it must match the channel's noise
# variance. The log-likelihood scores below are divided by this value.
SIGMA2 = 10.0
def _block_ll_scores(Yb: np.ndarray,
                     C1: np.ndarray,
                     C2: np.ndarray,
                     sqrtG: float
                     ) -> tuple[np.ndarray, np.ndarray]:
    """
    Score every codeword against one interleaved block under both states.

    The even-indexed samples of ``Yb`` are matched against ``C1`` and the
    odd-indexed samples against ``C2``. State 1 treats the even half as the
    boosted one, state 2 the odd half. Each score is the correlation minus
    half the codeword energy under that state, divided by ``SIGMA2``.

    Returns (scores_state1, scores_state2), one entry per codeword row.
    """
    even_samples = Yb[0::2]
    odd_samples = Yb[1::2]

    # Per-codeword energy of each half (from the -||Y - H_s C||^2 term).
    energy_c1 = np.sum(C1**2, axis=1)
    energy_c2 = np.sum(C2**2, axis=1)
    penalty_state1 = G * energy_c1 + energy_c2
    penalty_state2 = energy_c1 + G * energy_c2

    # Correlation of each received half with the matching codeword half.
    proj_even = even_samples @ C1.T
    proj_odd = odd_samples @ C2.T
    match_state1 = sqrtG * proj_even + proj_odd
    match_state2 = proj_even + sqrtG * proj_odd

    scores_state1 = (match_state1 - 0.5 * penalty_state1) / SIGMA2
    scores_state2 = (match_state2 - 0.5 * penalty_state2) / SIGMA2
    return scores_state1, scores_state2
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
    """
    Decode block by block without committing to a channel state.

    For each block the two per-state score vectors from _block_ll_scores are
    merged with logaddexp, and the symbol with the largest merged score is
    emitted.
    """
    block_len = C.shape[1]
    assert Y.size % block_len == 0, "Y length must be a multiple of codeword length"

    split = block_len // 2
    first_half = C[:, :split]
    second_half = C[:, split:]
    root_gain = np.sqrt(G)

    def best_symbol(block: np.ndarray) -> str:
        ll_state1, ll_state2 = _block_ll_scores(block, first_half, second_half, root_gain)
        merged = np.logaddexp(ll_state1, ll_state2)
        return ALPHABET[int(np.argmax(merged))]

    block_count = Y.size // block_len
    return "".join(
        best_symbol(Y[i * block_len:(i + 1) * block_len])
        for i in range(block_count)
    )
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> tuple[str, int]:
    """
    Joint ML state estimation and decoding.

    - For each block, get the per-state scores via _block_ll_scores.
    - Pick the best symbol index under each state and add its score to that
      state's running total.
    - Choose the state with the higher total (ties go to state 1).
    - Build the string from the best-symbol indices of the chosen state.

    Returns (decoded_string, estimated_state), where the state is 1 or 2.
    """
    n = C.shape[1]
    assert Y.size % n == 0, "Y length must be a multiple of codeword length"
    num_blocks = Y.size // n

    half = n // 2
    C1, C2 = C[:, :half], C[:, half:]
    sqrtG = np.sqrt(G)

    total1, total2 = 0.0, 0.0
    best1, best2 = [], []
    for k in range(num_blocks):
        Yb = Y[k * n:(k + 1) * n]
        s1, s2 = _block_ll_scores(Yb, C1, C2, sqrtG)
        idx1 = int(np.argmax(s1))
        idx2 = int(np.argmax(s2))
        total1 += s1[idx1]
        total2 += s2[idx2]
        best1.append(idx1)
        best2.append(idx2)

    s_est = 1 if total1 >= total2 else 2
    chosen = best1 if s_est == 1 else best2
    decoded = "".join(ALPHABET[i] for i in chosen)
    return decoded, s_est

View file

@ -1,82 +1,65 @@
# encoder.py — 2-char/1-codeword implementation
import numpy as np
from typing import Tuple
import sys
##############################################################################
# Public constants
##############################################################################
# 64-symbol alphabet: a-z, A-Z, 0-9, space and period.
ALPHABET = (
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789 ."
)
ALPHABET_LENGTH = len(ALPHABET)
assert ALPHABET_LENGTH == 64

# List form of the alphabet, kept for code that still imports CHAR_SET.
CHAR_SET = list(ALPHABET)
assert len(CHAR_SET) == 64

CHAR_TO_IDX = {c: i for i, c in enumerate(ALPHABET)}
IDX_TO_CHAR = {i: c for c, i in CHAR_TO_IDX.items()}

G = 10.0               # channel gain
ENERGY_LIMIT = 2000.0  # global limit ‖x‖²
TEXT_LENGTH = 40       # must stay 40
TEXT_LEN = TEXT_LENGTH  # older name for the same constant


def pair_to_index(a: str, b: str) -> int:
    """Pack two alphabet characters into one codeword index in [0, 4095].

    The position of `a` fills the high 6 bits and the position of `b` the
    low 6 bits. Raises KeyError if a character is not in ALPHABET.
    """
    i1 = CHAR_TO_IDX[a]
    i2 = CHAR_TO_IDX[b]
    return (i1 << 6) + i2


##############################################################################
# Hadamard-codebook utilities
##############################################################################
def hadamard(r: int) -> np.ndarray:
    """Return the 2**r x 2**r Hadamard matrix (entries +1/-1, float32).

    Built recursively as [[H, H], [H, -H]] starting from [[1]].

    Raises:
        ValueError: if `r` is negative (the recursion would never stop).
    """
    if r < 0:
        raise ValueError("r must be non-negative")
    if r == 0:
        return np.array([[1.]], dtype=np.float32)
    H = hadamard(r - 1)
    return np.block([[H, H], [H, -H]])


# Older private name, still referenced by _Br.
_hadamard = hadamard
def _Br(r: int) -> np.ndarray:
    """Stack the order-r Hadamard matrix on its negation: 2^(r+1) rows, 2^r columns."""
    base = _hadamard(r)
    negated = -base
    return np.concatenate((base, negated), axis=0)
def Br(r: int) -> np.ndarray:
    """Stack the order-r Hadamard matrix on its negation: 2^(r+1) rows, 2^r columns."""
    base = hadamard(r)
    negated = -base
    return np.concatenate((base, negated), axis=0)
##############################################################################
# Public API
##############################################################################
def make_codebook(r: int = 11,
                  num_blocks: int = TEXT_LENGTH // 2,
                  Eb: float = ENERGY_LIMIT) -> tuple[np.ndarray, float]:
    """Build the scaled codebook C used by both the encoder and the decoder.

    C is Br(r) placed side by side with itself; with the default r=11 that
    is a 4096 x 4096 matrix. The scale alpha is chosen so that, after the
    per-sample duplication done in encode_message, a message of
    `num_blocks` codewords has total energy `Eb`.

    Args:
        r: Hadamard order passed to Br.
        num_blocks: number of codewords per message.
        Eb: total energy budget of one message.

    Returns:
        (C, alpha): the float32 codebook and the per-sample energy alpha.
    """
    B = Br(r)
    C = np.hstack((B, B)).astype(np.float32)
    n = C.shape[1]
    dup_factor = 2  # every sample is transmitted twice
    alpha = Eb / (num_blocks * dup_factor * n)
    C *= np.sqrt(alpha, dtype=C.dtype)
    return C, alpha
def pair_to_index(a: str, b: str) -> int:
    """Map a character pair to its codeword row: 64 * position(a) + position(b)."""
    high = CHAR_TO_IDX[a]
    low = CHAR_TO_IDX[b]
    return high * 64 + low
def index_to_pair(k: int) -> tuple[str, str]:
    """Split a codeword index back into its two characters.

    Raises ValueError when `k` lies outside [0, 4095].
    """
    if not 0 <= k < 4096:
        raise ValueError("index out of range [0,4095]")
    high, low = divmod(k, 64)
    return IDX_TO_CHAR[high], IDX_TO_CHAR[low]
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
    """Encode a 40-character message into the transmit signal.

    Each pair of characters selects one codeword row of `C`. The rows are
    concatenated and every sample is then repeated once, so that the
    channel's even and odd indices each carry one copy.

    Args:
        msg: message of exactly TEXT_LENGTH characters from the alphabet.
        C: codebook as returned by make_codebook.

    Returns:
        The float32 signal.

    Raises:
        ValueError: if the message length is wrong.
        KeyError: if a character is not in the alphabet.
        RuntimeError: if the signal energy is not close to ENERGY_LIMIT.
    """
    if len(msg) != TEXT_LENGTH:
        raise ValueError("Message must be exactly 40 characters.")

    idxs = [CHAR_TO_IDX[c] for c in msg]
    # Row index = first character in the high 6 bits, second in the low 6.
    pair_idxs = [(idxs[i] << 6) | idxs[i + 1] for i in range(0, TEXT_LENGTH, 2)]
    signal = np.repeat(C[pair_idxs].ravel(), 2).astype(np.float32)

    # Sanity check: the codebook scaling should hit the budget exactly,
    # barring float error.
    e = float(signal.dot(signal))
    if not np.isclose(e, ENERGY_LIMIT, atol=1e-3):
        raise RuntimeError(f"energy sanity check failed ({e:.3f} ≠ 2000)")
    return signal
if __name__ == "__main__":
    # Script entry point: take the message from the first command-line
    # argument (or prompt for it), encode it and write it to input.txt.
    cli_args = sys.argv[1:]
    msg = (
        cli_args[0]
        if cli_args
        else input(f"Enter a message ({TEXT_LENGTH} characters): ").strip()
    )
    if len(msg) != TEXT_LENGTH:
        print(f"Message must be exactly {TEXT_LENGTH} characters.")
        sys.exit(1)

    codebook, _alpha = make_codebook()
    np.savetxt("input.txt", encode_message(msg, codebook))
    print("Signal written to input.txt")

View file

@ -1,57 +0,0 @@
# encoder_backup.py
import numpy as np
# System parameters
G = 10.0 # power gain of the boosted sample half (even or odd, depending on the channel state)
ENERGY_LIMIT = 2000.0 # total energy budget; make_codebook() divides it by num_blocks
# 64-symbol alphabet: a-z, A-Z, 0-9, space and period.
ALPHABET = (
[chr(i) for i in range(ord('a'), ord('z')+1)] +
[chr(i) for i in range(ord('A'), ord('Z')+1)] +
[str(i) for i in range(10)] +
[' ', '.']
)
assert len(ALPHABET) == 64, "Alphabet must be size 64"
def hadamard(r: int) -> np.ndarray:
    """Return the 2**r x 2**r Hadamard matrix with float entries +1/-1.

    Starts from [[1]] and applies the doubling step [[M, M], [M, -M]]
    `r` times.

    Raises:
        ValueError: if `r` is negative.
    """
    if r < 0:
        raise ValueError("r must be non-negative")
    M = np.array([[1]], dtype=float)
    for _ in range(r):
        M = np.block([[M, M], [M, -M]])
    return M
def Br(r: int) -> np.ndarray:
    """Stack the order-r Hadamard matrix on top of its negation."""
    base = hadamard(r)
    negated = -base
    return np.concatenate((base, negated), axis=0)
def make_codebook(r: int, num_blocks: int) -> np.ndarray:
    """
    Build the codebook with one row per alphabet symbol.

    The first len(ALPHABET) rows of Br(r) are placed side by side with
    themselves, giving shape (64, 2^(r+1)). The rows are scaled so that each
    one carries 99% of ENERGY_LIMIT / num_blocks.

    Raises:
        ValueError: if num_blocks is not positive, or if Br(r) has fewer
            rows than the alphabet has symbols.
    """
    if num_blocks <= 0:
        raise ValueError("num_blocks must be positive")
    B_full = Br(r)
    if B_full.shape[0] < len(ALPHABET):
        raise ValueError("r is too small: Br(r) has fewer rows than the alphabet")
    B = B_full[:len(ALPHABET), :]        # shape (64, 2^r)
    C = np.hstack([B, B]).astype(float)  # shape (64, 2^{r+1})
    # Every entry is +1 or -1, so all rows share the norm of row 0.
    raw_norm = np.sum(C[0]**2)
    margin = 0.99  # stay slightly under the per-block energy share
    alpha = margin * (ENERGY_LIMIT / num_blocks) / raw_norm
    return np.sqrt(alpha) * C
def interleave(c: np.ndarray) -> np.ndarray:
    """Interleave the two halves of `c`.

    The first half of `c` goes to the even output positions and the second
    half to the odd positions. The result is a new float array of the same
    length.

    Raises:
        ValueError: if `c` has an odd number of samples.
    """
    if c.size % 2:
        raise ValueError("interleave() needs an even number of samples")
    half = c.size // 2
    x = np.empty(c.size)
    x[0::2] = c[:half]
    x[1::2] = c[half:]
    return x
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
    """
    Turn a 40-character message into one long vector of interleaved codewords.

    Each character selects the codebook row at its position in ALPHABET; the
    rows are interleaved and joined in message order. Raises ValueError when
    the message is not exactly 40 characters long.
    """
    if len(msg) != 40:
        raise ValueError("Message must be exactly 40 characters.")
    rows = [ALPHABET.index(ch) for ch in msg]
    return np.concatenate([interleave(C[row]) for row in rows])

75
main.py
View file

@ -1,65 +1,70 @@
#!/usr/bin/env python3
"""
Unchanged CLI — only the *r* parameter and the message-length check moved to the encoder.
"""
import argparse, sys, numpy as np, subprocess, pathlib, os, tempfile
import argparse
import numpy as np
import encoder, decoder, channel
import subprocess, pathlib, os, tempfile, sys
INPUT_FILE = OUTPUT_FILE = None
def transmit(msg, C): return encoder.encode_message(msg, C)
def receive_local(x): return channel.channel(x)
def transmit(msg, C):
return encoder.encode_message(msg, C)
def receive_server(x, host, port):
global INPUT_FILE, OUTPUT_FILE
if INPUT_FILE:
in_f, rm_in = INPUT_FILE, False
def receive_local(x):
return channel.channel(x)
def receive_server(x, host, port, input_file=None, output_file=None):
in_f = input_file or tempfile.NamedTemporaryFile(suffix='.txt', delete=False).name
np.savetxt(in_f, x)
else:
tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
np.savetxt(tf.name, x); tf.close()
in_f, rm_in = tf.name, True
if OUTPUT_FILE:
out_f, rm_out = OUTPUT_FILE, False
else:
fd, out_f = tempfile.mkstemp(suffix='.txt'); os.close(fd); rm_out = True
cmd = [sys.executable, str(pathlib.Path(__file__).parent/'client.py'),
out_f = output_file or tempfile.mkstemp(suffix='.txt')[1]
cmd = [
sys.executable, str(pathlib.Path(__file__).parent / 'client.py'),
'--input_file', in_f, '--output_file', out_f,
'--srv_hostname', host, '--srv_port', str(port)]
'--srv_hostname', host, '--srv_port', str(port)
]
try:
subprocess.run(cmd, check=True)
Y = np.loadtxt(out_f)
finally:
if rm_in and os.path.exists(in_f): os.remove(in_f)
if rm_out and os.path.exists(out_f): os.remove(out_f)
if not input_file and os.path.exists(in_f):
os.remove(in_f)
if not output_file and os.path.exists(out_f):
os.remove(out_f)
return Y
def receive(x, mode, host, port):
return receive_local(x) if mode=='local' else receive_server(x,host,port)
def test(msg, n_trials, mode, host, port):
C, _ = encoder.make_codebook() # r=11 by default
def receive(x, mode, host, port, input_file=None, output_file=None):
if mode == 'local':
return receive_local(x)
return receive_server(x, host, port, input_file, output_file)
def test(msg, n_trials, mode, host, port, input_file=None, output_file=None):
C, _ = encoder.make_codebook()
print(f"Using codebook with {C.shape[0]} codewords, {C.shape[1]} symbols each.")
ok = 0
for _ in range(n_trials):
x = transmit(msg, C)
print(f"Transmitted {len(x):,} samples (energy={np.dot(x, x):.2f})")
y = receive(x, mode, host, port)
y = receive(x, mode, host, port, input_file, output_file)
print(f"Received {len(y):,} samples (energy={np.dot(y, y):.2f})")
est, _ = decoder.decode_blocks_with_state(y, C)
if est == msg: ok += 1
if est == msg:
ok += 1
print(f"{ok}/{n_trials} perfect decodes ({100 * ok / n_trials:.2f}%)")
def _args():
def parse_args():
p = argparse.ArgumentParser()
p.add_argument('-m', '--message', required=True, help='exactly 40 chars')
p.add_argument('-n', '--trials', type=int, default=200)
p.add_argument('--mode', choices=['local', 'server'], default='local')
p.add_argument('--hostname',default='iscsrv72.epfl.ch'); p.add_argument('--port',type=int,default=80)
p.add_argument('-i','--input_file'); p.add_argument('-o','--output_file')
p.add_argument('--hostname', default='iscsrv72.epfl.ch')
p.add_argument('--port', type=int, default=80)
p.add_argument('-i', '--input_file')
p.add_argument('-o', '--output_file')
return p.parse_args()
if __name__ == '__main__':
a=_args(); INPUT_FILE=a.input_file; OUTPUT_FILE=a.output_file
test(a.message, a.trials, a.mode, a.hostname, a.port)
a = parse_args()
test(a.message, a.trials, a.mode, a.hostname, a.port, a.input_file, a.output_file)

View file

@ -1,119 +0,0 @@
#!/usr/bin/env python3
import argparse
import sys
import numpy as np
import encoder
import decoder
import channel
import subprocess
import pathlib
import os
import tempfile
# Optional file paths for debugging, set from --input_file / --output_file
# in the __main__ block. receive_server() uses them when set; when left as
# None it creates temporary files and removes them afterwards.
INPUT_FILE = None
OUTPUT_FILE = None
def transmit(msg, C):
    """Encode `msg` with codebook `C` and return the samples to send."""
    samples = encoder.encode_message(msg, C)
    return samples
def receive_local(c):
    """Run the samples `c` through the local channel and return its output."""
    received = channel.channel(c)
    return received
def receive_server(c, hostname, port):
    """Send the samples `c` through the remote channel via client.py.

    The samples are written to INPUT_FILE, client.py (next to this file) is
    run as a subprocess, and the result is read from OUTPUT_FILE. When
    either global is unset, a temporary file is used instead and removed
    afterwards, also when a step fails.

    Raises:
        subprocess.CalledProcessError: if client.py exits with an error.
    """
    delete_in = not INPUT_FILE
    delete_out = not OUTPUT_FILE
    in_name = out_name = None
    try:
        # Input file: the caller's path, or a fresh temp file.
        if delete_in:
            fd, in_name = tempfile.mkstemp(suffix='.txt')
            os.close(fd)
        else:
            in_name = INPUT_FILE
        np.savetxt(in_name, c)

        # Output file: the caller's path, or a fresh temp file.
        if delete_out:
            fd, out_name = tempfile.mkstemp(suffix='.txt')
            os.close(fd)
        else:
            out_name = OUTPUT_FILE

        # Call client.py with the same interpreter that runs this script.
        cmd = [
            sys.executable,
            str(pathlib.Path(__file__).parent / 'client.py'),
            '--input_file', in_name,
            '--output_file', out_name,
            '--srv_hostname', hostname,
            '--srv_port', str(port)
        ]
        subprocess.run(cmd, check=True)
        Y = np.loadtxt(out_name)
    finally:
        if delete_in and in_name and os.path.exists(in_name):
            os.remove(in_name)
        if delete_out and out_name and os.path.exists(out_name):
            os.remove(out_name)
    return Y
def receive(c, mode, hostname, port):
    """Dispatch `c` to the local or the server channel according to `mode`.

    Raises ValueError for any mode other than 'local' or 'server'.
    """
    if mode == 'server':
        return receive_server(c, hostname, port)
    if mode != 'local':
        raise ValueError("mode must be 'local' or 'server'")
    return receive_local(c)
def test_performance(msg, num_trials, mode, hostname, port):
    """Run `num_trials` transmit/receive/decode rounds and print the results.

    Each trial prints the estimated state, the decoded text and the number
    of character errors; a summary with the share of perfect decodes follows.

    Raises:
        ValueError: if `msg` is not exactly 40 characters long.
    """
    if len(msg) != 40:
        raise ValueError('Message must be exactly 40 characters.')

    # Build the same codebook for transmitter and receiver.
    C = encoder.make_codebook(r=11, num_blocks=40)
    successes = 0

    print(f"Original message: {msg!r}")
    print(f"Running {num_trials} trials over '{mode}' channel...\n")

    for i in range(num_trials):
        # TX -> channel -> RX
        c = transmit(msg, C)
        Y = receive(c, mode, hostname, port)

        # Decode with state estimation.
        est, s_est = decoder.decode_blocks_with_state(Y, C)

        # Count character errors. zip() stops at the shorter string, so a
        # length difference is added as extra errors.
        errors = sum(1 for a, b in zip(est, msg) if a != b)
        errors += abs(len(est) - len(msg))
        if errors == 0:
            successes += 1

        print(f"Trial {i+1:3d}: state={s_est} decoded={est!r} errors={errors}")

    # Avoid dividing by zero when no trials were requested.
    pct = 100 * successes / num_trials if num_trials else 0.0
    print("\n=== Summary ===")
    print(f" Total trials: {num_trials}")
    print(f" Perfect decodings: {successes}")
    print(f" Overall accuracy: {pct:.2f}%")
def parse_args(argv=None):
    """Parse the command-line options of the test script.

    Args:
        argv: list of argument strings; defaults to sys.argv[1:] when None.

    Returns:
        The argparse namespace with message, trials, mode, hostname, port,
        input_file and output_file.
    """
    p = argparse.ArgumentParser(description='Test comms system')
    p.add_argument('--message', '-m', required=True, help='40-char message')
    p.add_argument('--trials', '-n', type=int, default=200, help='Number of trials')
    p.add_argument('--mode', choices=['local', 'server'], default='local')
    p.add_argument('--hostname', default='iscsrv72.epfl.ch')
    p.add_argument('--port', type=int, default=80)
    p.add_argument('--input_file', '-i', default=None,
                   help='(server mode) where to write input.txt')
    p.add_argument('--output_file', '-o', default=None,
                   help='(server mode) where to write output.txt')
    return p.parse_args(argv)
if __name__=='__main__':
    # Script entry point: publish the optional debug paths as module
    # globals, then run the performance test with the parsed options.
    args = parse_args()
    INPUT_FILE, OUTPUT_FILE = args.input_file, args.output_file
    test_performance(
        args.message,
        args.trials,
        args.mode,
        args.hostname,
        args.port,
    )