feat: good code

This commit is contained in:
appellet 2025-05-30 00:44:20 +02:00
parent 8b876c0d06
commit 80da1ea764
5 changed files with 255 additions and 198 deletions

View file

@ -34,7 +34,7 @@ All commands assume you are in the project root directory.
### 1. Test locally for 1 trial ### 1. Test locally for 1 trial
```bash ```bash
python3 main.py \ python3 main_backup.py \
--message "Lorem ipsum dolor sit amet. consectetuer" \ --message "Lorem ipsum dolor sit amet. consectetuer" \
--trials 1 \ --trials 1 \
--mode local --mode local
@ -43,7 +43,7 @@ python3 main.py \
### 2. Test locally for 500 trials ### 2. Test locally for 500 trials
```bash ```bash
python3 main.py \ python3 main_backup.py \
--message "Lorem ipsum dolor sit amet. consectetuer" \ --message "Lorem ipsum dolor sit amet. consectetuer" \
--trials 500 \ --trials 500 \
--mode local --mode local
@ -54,7 +54,7 @@ python3 main.py \
This will write `input.txt` and `output.txt` in your working directory. This will write `input.txt` and `output.txt` in your working directory.
```bash ```bash
python3 main.py \ python3 main_backup.py \
--message "Lorem ipsum dolor sit amet. consectetuer" \ --message "Lorem ipsum dolor sit amet. consectetuer" \
--trials 1 \ --trials 1 \
--mode server \ --mode server \

View file

@ -1,61 +1,43 @@
# decoder.py
import numpy as np import numpy as np
from encoder import ALPHABET, G
from encoder import (
G, CHAR_SET, pair_to_index, index_to_pair, make_codebook
)
# NOTE(review): debug output that runs every time this module is imported.
# np.__config__.show() presumably prints the NumPy build configuration itself
# and returns None, in which case this line also prints "None" — confirm, and
# consider removing it once BLAS/threading diagnostics are no longer needed.
print(np.__config__.show())
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> tuple[str, int]:
    """
    Decode every block of Y and estimate one channel state for the whole signal.

    Y is cut into blocks of ``2 * C.shape[1]`` samples (the original comment
    describes this as the codeword length after sample duplication), and each
    block is split into its even-indexed and odd-indexed samples.

    Both state hypotheses are scored for all blocks and all rows of C at once:
        state 1:  sqrt(G) * (Y_even @ C.T) + (Y_odd @ C.T)
        state 2:  (Y_even @ C.T) + sqrt(G) * (Y_odd @ C.T)
    The state whose per-block maximum scores add up to the larger total is
    chosen (a tie selects state 1). The per-block argmax under that state is
    turned into two characters with index_to_pair.

    Parameters:
        Y: received samples; Y.size must be a multiple of 2 * C.shape[1].
        C: codebook, one codeword per row.

    Returns:
        (decoded_string, estimated_state) with estimated_state in {1, 2}.

    Raises:
        AssertionError: if Y.size is not a whole number of blocks.
    """
    n_cw = C.shape[1] * 2  # received samples per codeword
    nb = Y.size // n_cw
    assert nb * n_cw == Y.size, "length mismatch"
    sqrtg = np.sqrt(G)

    # One row per block, then de-interleave even / odd samples for all blocks
    Y_blocks = Y.reshape(nb, n_cw)
    Y_even = Y_blocks[:, ::2].astype(np.float32, copy=False)
    Y_odd = Y_blocks[:, 1::2].astype(np.float32, copy=False)
    C = C.astype(np.float32, copy=False)

    # Both hypotheses use the same two correlations, so compute each product
    # once instead of once per hypothesis.
    corr_even = Y_even @ C.T
    corr_odd = Y_odd @ C.T
    s1 = sqrtg * corr_even + corr_odd
    s2 = corr_even + sqrtg * corr_odd

    best_if_s1 = np.argmax(s1, axis=1)
    best_if_s2 = np.argmax(s2, axis=1)
    tot1 = np.sum(np.max(s1, axis=1))
    tot2 = np.sum(np.max(s2, axis=1))

    state = 1 if tot1 >= tot2 else 2
    indices = best_if_s1 if state == 1 else best_if_s2
    decoded = ''.join(a + b for a, b in (index_to_pair(i) for i in indices))
    return decoded, state
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str: def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
""" return decode_blocks_with_state(Y, C)[0]
Decode received samples by maximum correlation score (state unknown, per-block).
"""
n = C.shape[1]
half = n // 2
num = Y.size // n
C1 = C[:, :half]
C2 = C[:, half:]
sqrtG = np.sqrt(G)
recovered = []
for k in range(num):
Yb = Y[k * n:(k + 1) * n]
Ye, Yo = Yb[0::2], Yb[1::2]
s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
score = np.maximum(s1, s2)
best = int(np.argmax(score))
recovered.append(ALPHABET[best])
return "".join(recovered)
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> tuple[str, int]:
    """
    Estimate one channel state for the whole signal, then decode every block.

    1) Estimate the single channel state s in {1, 2} by comparing the total
       energy on even vs. odd positions across the entire Y. State 1 is chosen
       only when the even-position energy is strictly larger.
    2) Decode **all** blocks using that one state's scoring rule.

    Parameters:
        Y: received samples; only the first ``Y.size // C.shape[1]`` full
           blocks are decoded.
        C: codebook, one codeword per row; the first half of each row is
           correlated with the even samples, the second half with the odd ones.

    Returns:
        (decoded_string, estimated_state).
    """
    n = C.shape[1]
    half = n // 2
    num = Y.size // n
    C1 = C[:, :half]
    C2 = C[:, half:]
    sqrtG = np.sqrt(G)

    # 1) state estimate from full-length energies
    E_even = np.sum(Y[0::2] ** 2)
    E_odd = np.sum(Y[1::2] ** 2)
    s_est = 1 if E_even > E_odd else 2

    # 2) the state is fixed for the whole message, so choose the weights once
    #    instead of branching inside the loop (multiplying by 1.0 is exact).
    w_even, w_odd = (sqrtG, 1.0) if s_est == 1 else (1.0, sqrtG)

    recovered = []
    for k in range(num):
        Yb = Y[k * n:(k + 1) * n]
        score = w_even * (Yb[0::2] @ C1.T) + w_odd * (Yb[1::2] @ C2.T)
        recovered.append(ALPHABET[int(np.argmax(score))])
    return "".join(recovered), s_est

103
decoder_backup.py Normal file
View file

@ -0,0 +1,103 @@
# decoder_backup.py
import numpy as np
from encoder_backup import ALPHABET, G
# Match the channel's noise variance
SIGMA2 = 10.0
def _block_ll_scores(Yb: np.ndarray,
                     C1: np.ndarray,
                     C2: np.ndarray,
                     sqrtG: float
                     ) -> tuple[np.ndarray, np.ndarray]:
    """
    Score every codeword against one interleaved block under both channel states.

    Yb is split into even-indexed and odd-indexed samples; C1 rows are matched
    against the even samples and C2 rows against the odd samples. For each
    state the score is ``(correlation - 0.5 * energy) / SIGMA2``, where the
    boosted half is weighted by sqrtG in the correlation and by G in the
    energy term (state 1 boosts the even half, state 2 the odd half).

    Returns (scores_state1, scores_state2), one entry per codeword row.
    """
    even_rx = Yb[0::2]
    odd_rx = Yb[1::2]

    # Per-codeword energy of each half (the squared-norm penalty terms)
    energy_first = np.sum(C1**2, axis=1)
    energy_second = np.sum(C2**2, axis=1)
    penalty_state1 = G * energy_first + energy_second
    penalty_state2 = energy_first + G * energy_second

    # Correlation of each received half with the matching codeword half
    proj_even = even_rx @ C1.T
    proj_odd = odd_rx @ C2.T
    match_state1 = sqrtG * proj_even + proj_odd
    match_state2 = proj_even + sqrtG * proj_odd

    ll_state1 = (match_state1 - 0.5 * penalty_state1) / SIGMA2
    ll_state2 = (match_state2 - 0.5 * penalty_state2) / SIGMA2
    return ll_state1, ll_state2
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
    """
    Decode Y block by block without committing to a channel state.

    For every block the two per-state score vectors from _block_ll_scores are
    combined with logaddexp, and the codeword with the largest combined score
    selects the output symbol from ALPHABET.
    """
    block_len = C.shape[1]
    assert Y.size % block_len == 0, "Y length must be a multiple of codeword length"
    block_count = Y.size // block_len
    mid = block_len // 2
    first_half = C[:, :mid]
    second_half = C[:, mid:]
    root_gain = np.sqrt(G)

    def _symbol_at(start: int) -> str:
        # Combine both state hypotheses for the block beginning at `start`
        ll_a, ll_b = _block_ll_scores(
            Y[start:start + block_len], first_half, second_half, root_gain
        )
        combined = np.logaddexp(ll_a, ll_b)
        return ALPHABET[int(np.argmax(combined))]

    return "".join(_symbol_at(i * block_len) for i in range(block_count))
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> tuple[str, int]:
    """
    Joint-ML state estimation and decoding.

    - For each block, get per-state scores via _block_ll_scores.
    - Pick the best symbol index under each state and sum those best scores
      across blocks, separately for each state.
    - Choose the state with the higher total (a tie selects state 1).
    - Reconstruct the string from the best-symbol indices of that state.

    Parameters:
        Y: received samples; Y.size must be a multiple of C.shape[1].
        C: codebook, one codeword per row.

    Returns:
        (decoded_string, estimated_state).

    Raises:
        AssertionError: if Y.size is not a multiple of the codeword length.
    """
    n = C.shape[1]
    assert Y.size % n == 0, "Y length must be a multiple of codeword length"
    num_blocks = Y.size // n
    half = n // 2
    C1, C2 = C[:, :half], C[:, half:]
    sqrtG = np.sqrt(G)

    total1, total2 = 0.0, 0.0
    best1, best2 = [], []

    for k in range(num_blocks):
        s1, s2 = _block_ll_scores(Y[k * n:(k + 1) * n], C1, C2, sqrtG)
        idx1 = int(np.argmax(s1))
        idx2 = int(np.argmax(s2))
        total1 += s1[idx1]
        total2 += s2[idx2]
        best1.append(idx1)
        best2.append(idx2)

    s_est = 1 if total1 >= total2 else 2
    chosen = best1 if s_est == 1 else best2
    return "".join(ALPHABET[i] for i in chosen), s_est

View file

@ -1,56 +1,82 @@
# encoder.py # encoder.py — 2-char/1-codeword implementation
import numpy as np import numpy as np
from typing import Tuple
# System parameters ##############################################################################
G = 10.0 # power gain for even samples # Public constants
ENERGY_LIMIT = 2000.0 # total energy per block ##############################################################################
ALPHABET = ( CHAR_SET = (
[chr(i) for i in range(ord('a'), ord('z')+1)] + [chr(i) for i in range(ord('a'), ord('z')+1)] +
[chr(i) for i in range(ord('A'), ord('Z')+1)] + [chr(i) for i in range(ord('A'), ord('Z')+1)] +
[str(i) for i in range(10)] + [str(i) for i in range(10)] +
[' ', '.'] [' ', '.']
) )
assert len(ALPHABET) == 64, "Alphabet must be size 64" assert len(CHAR_SET) == 64
CHAR_TO_IDX = {c: i for i, c in enumerate(CHAR_SET)}
IDX_TO_CHAR = {i: c for c, i in CHAR_TO_IDX.items()}
G = 10.0 # channel gain
ENERGY_LIMIT = 2000.0 # global limit ‖x‖²
TEXT_LEN = 40 # must stay 40
def hadamard(r: int) -> np.ndarray: ##############################################################################
# Hadamard-codebook utilities
##############################################################################
def _hadamard(r: int) -> np.ndarray:
if r == 0: if r == 0:
return np.array([[1]], dtype=float) return np.array([[1.]], dtype=np.float32)
M = hadamard(r-1) H = _hadamard(r-1)
return np.block([[M, M], [M, -M]]) return np.block([[H, H],
[H, -H]])
def _Br(r: int) -> np.ndarray:
    """Return the order-r Hadamard matrix stacked on top of its negation."""
    base = _hadamard(r)
    stacked = np.vstack((base, -base))  # 2^(r+1) rows × 2^r columns
    return stacked
def Br(r: int) -> np.ndarray: ##############################################################################
M = hadamard(r) # Public API
return np.vstack([M, -M]) ##############################################################################
def make_codebook(r: int = 11,
num_blocks: int = TEXT_LEN//2,
def make_codebook(r: int, num_blocks: int) -> np.ndarray: energy_budget: float = ENERGY_LIMIT
) -> Tuple[np.ndarray, float]:
""" """
Build 64x64 codebook and scale blocks so energy per block ENERGY_LIMIT/num_blocks Builds the scaled codebook C (4096×4096) used by both encoder & decoder.
α is chosen so that **after the per-sample duplication** in encode_message,
a full 20-block message consumes exactly `energy_budget`.
""" """
B = Br(r) B = _Br(r) # 4096 × 2048
C = np.hstack([B, B]).astype(float) C = np.hstack((B, B)).astype(np.float32) # 4096 × 4096
raw_norm = np.sum(C[0]**2) n = C.shape[1] # 4096
margin = 0.95 dup_factor = 2 # sample-duplication
alpha = margin * (ENERGY_LIMIT / num_blocks) / raw_norm alpha = energy_budget / (num_blocks * dup_factor * n)
return np.sqrt(alpha) * C C *= np.sqrt(alpha, dtype=C.dtype)
return C, alpha
def pair_to_index(a: str, b: str) -> int:
    """Map a character pair to one index: 64 * (index of a) + (index of b)."""
    high = CHAR_TO_IDX[a]
    low = CHAR_TO_IDX[b]
    return 64 * high + low
def interleave(c: np.ndarray) -> np.ndarray: def index_to_pair(k: int) -> tuple[str, str]:
half = c.size // 2 if not 0 <= k < 4096:
x = np.empty(c.size) raise ValueError("index out of range [0,4095]")
x[0::2] = c[:half] return IDX_TO_CHAR[k // 64], IDX_TO_CHAR[k % 64]
x[1::2] = c[half:]
return x
def encode_message(msg: str, C: np.ndarray) -> np.ndarray: def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
""" """
Encode 40-character message into interleaved code symbols Encode a 40-character message. Each 2-character pair one codeword row.
After concatenation the whole signal is duplicated sample-wise so that
the channel's even / odd indices each carry one clean copy.
""" """
if len(msg) != 40: if len(msg) != TEXT_LEN:
raise ValueError("Message must be exactly 40 characters.") raise ValueError("Message must be exactly 40 characters.")
idx = [ALPHABET.index(ch) for ch in msg]
blocks = [interleave(C[i]) for i in idx] pairs = [(msg[i], msg[i+1]) for i in range(0, TEXT_LEN, 2)]
return np.concatenate(blocks) rows = [C[pair_to_index(a, b)] for a, b in pairs] # 20×4096
signal = np.concatenate(rows).astype(np.float32)
signal = np.repeat(signal, 2) # duplicate
# tight numeric safety-check (≡ 2000, barring float error)
e = np.sum(signal**2)
if not np.isclose(e, ENERGY_LIMIT, atol=1e-3):
raise RuntimeError(f"energy sanity check failed ({e:.3f} ≠ 2000)")
return signal

140
main.py
View file

@ -1,119 +1,65 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
"""
Unchanged CLI: only the *r* and the assert on message length moved to encoder.
"""
import argparse, sys, numpy as np, subprocess, pathlib, os, tempfile
import encoder, decoder, channel
import argparse INPUT_FILE = OUTPUT_FILE = None
import sys
import numpy as np
import encoder
import decoder
import channel
import subprocess
import pathlib
import os
import tempfile
# Global paths for debugging def transmit(msg, C): return encoder.encode_message(msg, C)
INPUT_FILE = None def receive_local(x): return channel.channel(x)
OUTPUT_FILE = None
def transmit(msg, C): def receive_server(x, host, port):
return encoder.encode_message(msg, C)
def receive_local(c):
return channel.channel(c)
def receive_server(c, hostname, port):
global INPUT_FILE, OUTPUT_FILE global INPUT_FILE, OUTPUT_FILE
# write or temp-file for input
if INPUT_FILE: if INPUT_FILE:
in_name, delete_in = INPUT_FILE, False in_f, rm_in = INPUT_FILE, False
np.savetxt(in_name, c) np.savetxt(in_f, x)
else: else:
tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False) tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
np.savetxt(tf.name, c) np.savetxt(tf.name, x); tf.close()
in_name, delete_in = tf.name, True in_f, rm_in = tf.name, True
tf.close()
# write or temp-file for output
if OUTPUT_FILE: if OUTPUT_FILE:
out_name, delete_out = OUTPUT_FILE, False out_f, rm_out = OUTPUT_FILE, False
else: else:
fd, out_name = tempfile.mkstemp(suffix='.txt') fd, out_f = tempfile.mkstemp(suffix='.txt'); os.close(fd); rm_out = True
os.close(fd)
delete_out = True
# call client.py cmd = [sys.executable, str(pathlib.Path(__file__).parent/'client.py'),
cmd = [ '--input_file', in_f, '--output_file', out_f,
sys.executable, '--srv_hostname', host, '--srv_port', str(port)]
str(pathlib.Path(__file__).parent / 'client.py'),
'--input_file', in_name,
'--output_file', out_name,
'--srv_hostname', hostname,
'--srv_port', str(port)
]
try: try:
subprocess.run(cmd, check=True) subprocess.run(cmd, check=True)
Y = np.loadtxt(out_name) Y = np.loadtxt(out_f)
finally: finally:
if delete_in and os.path.exists(in_name): os.remove(in_name) if rm_in and os.path.exists(in_f): os.remove(in_f)
if delete_out and os.path.exists(out_name): os.remove(out_name) if rm_out and os.path.exists(out_f): os.remove(out_f)
return Y return Y
def receive(c, mode, hostname, port): def receive(x, mode, host, port):
if mode == 'local': return receive_local(x) if mode=='local' else receive_server(x,host,port)
return receive_local(c)
elif mode == 'server':
return receive_server(c, hostname, port)
else:
raise ValueError("mode must be 'local' or 'server'")
def test_performance(msg, num_trials, mode, hostname, port): def test(msg, n_trials, mode, host, port):
if len(msg) != 40: C, _ = encoder.make_codebook() # r=11 by default
raise ValueError('Message must be exactly 40 characters.') print(f"Using codebook with {C.shape[0]} codewords, {C.shape[1]} symbols each.")
ok = 0
for _ in range(n_trials):
x = transmit(msg, C)
print(f"Transmitted {len(x):,} samples (energy={np.dot(x,x):.2f})")
y = receive(x, mode, host, port)
print(f"Received {len(y):,} samples (energy={np.dot(y,y):.2f})")
est, _ = decoder.decode_blocks_with_state(y, C)
if est == msg: ok += 1
print(f"{ok}/{n_trials} perfect decodes ({100*ok/n_trials:.2f}%)")
# build the same codebook def _args():
C = encoder.make_codebook(r=5, num_blocks=40) p=argparse.ArgumentParser()
p.add_argument('-m','--message',required=True,help='exactly 40 chars')
successes = 0 p.add_argument('-n','--trials', type=int, default=200)
print(f"Original message: {msg!r}")
print(f"Running {num_trials} trials over '{mode}' channel...\n")
for i in range(num_trials):
# TX → channel → RX
c = transmit(msg, C)
Y = receive(c, mode, hostname, port)
# decode with stateestimation
est, s_est = decoder.decode_blocks_with_state(Y, C)
# count char errors
errors = sum(1 for a,b in zip(est, msg) if a!=b)
if errors == 0:
successes += 1
print(f"Trial {i+1:3d}: state={s_est} decoded={est!r} errors={errors}")
pct = 100 * successes / num_trials
print("\n=== Summary ===")
print(f" Total trials: {num_trials}")
print(f" Perfect decodings: {successes}")
print(f" Overall accuracy: {pct:.2f}%")
def parse_args():
p = argparse.ArgumentParser(description='Test comms system')
p.add_argument('--message','-m', required=True, help='40-char message')
p.add_argument('--trials','-n', type=int, default=200, help='Number of trials')
p.add_argument('--mode',choices=['local','server'],default='local') p.add_argument('--mode',choices=['local','server'],default='local')
p.add_argument('--hostname', default='iscsrv72.epfl.ch') p.add_argument('--hostname',default='iscsrv72.epfl.ch'); p.add_argument('--port',type=int,default=80)
p.add_argument('--port', type=int, default=80) p.add_argument('-i','--input_file'); p.add_argument('-o','--output_file')
p.add_argument('--input_file','-i', default=None,
help='(server mode) where to write input.txt')
p.add_argument('--output_file','-o',default=None,
help='(server mode) where to write output.txt')
return p.parse_args() return p.parse_args()
if __name__=='__main__': if __name__=='__main__':
args = parse_args() a=_args(); INPUT_FILE=a.input_file; OUTPUT_FILE=a.output_file
INPUT_FILE = args.input_file test(a.message, a.trials, a.mode, a.hostname, a.port)
OUTPUT_FILE = args.output_file
test_performance(args.message, args.trials,
args.mode, args.hostname, args.port)