feat: good code

appellet 2025-05-30 00:44:20 +02:00
parent 8b876c0d06
commit 80da1ea764
5 changed files with 255 additions and 198 deletions


@@ -34,7 +34,7 @@ All commands assume you are in the project root directory.
### 1. Test locally for 1 trial
```bash
python3 main.py \
python3 main_backup.py \
--message "Lorem ipsum dolor sit amet. consectetuer" \
--trials 1 \
--mode local
@@ -43,7 +43,7 @@ python3 main.py \
### 2. Test locally for 500 trials
```bash
python3 main.py \
python3 main_backup.py \
--message "Lorem ipsum dolor sit amet. consectetuer" \
--trials 500 \
--mode local
@@ -54,7 +54,7 @@ python3 main.py \
This will write `input.txt` and `output.txt` in your working directory.
```bash
python3 main.py \
python3 main_backup.py \
--message "Lorem ipsum dolor sit amet. consectetuer" \
--trials 1 \
--mode server \
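The server-mode command above is cut off by the diff hunk. For context, here is a plausible full invocation pieced together from the flags defined in main.py's argument parser below (hostname and port are that parser's defaults, and the file names follow the sentence above); treat it as an illustrative sketch rather than README text:

```bash
python3 main.py \
  --message "Lorem ipsum dolor sit amet. consectetuer" \
  --trials 1 \
  --mode server \
  --hostname iscsrv72.epfl.ch \
  --port 80 \
  --input_file input.txt \
  --output_file output.txt
```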


@@ -1,61 +1,43 @@
# decoder.py
import numpy as np
from encoder import ALPHABET, G
from encoder import (
G, CHAR_SET, pair_to_index, index_to_pair, make_codebook
)
print(np.__config__.show())
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray):
n_cw = C.shape[1] * 2 # 8192 samples after duplication
nb = Y.size // n_cw
assert nb * n_cw == Y.size, "length mismatch"
sqrtg = np.sqrt(G)
# --- Vectorized even/odd block extraction for speed (avoids a per-block Python loop)
Y_blocks = Y.reshape(nb, n_cw)
Y_even = Y_blocks[:, ::2] # shape (nb, 4096)
Y_odd = Y_blocks[:, 1::2] # shape (nb, 4096)
#print(f"Extracted {nb} blocks of {n_cw} samples each, ")
# --- Vectorize all block scoring
# Each: (nb, 4096) = (nb, 4096) @ (4096, 4096).T
#print(Y_blocks.dtype, C.dtype, sqrtg.dtype, Y_even.dtype, Y_odd.dtype)
Y_even = Y_even.astype(np.float32, copy=False)
Y_odd = Y_odd.astype(np.float32, copy=False)
C = C.astype(np.float32, copy=False)
s1 = sqrtg * (Y_even @ C.T) + (Y_odd @ C.T)
s2 = (Y_even @ C.T) + sqrtg * (Y_odd @ C.T)
#print(f"Scoring {nb} blocks with {C.shape[0]} codewords, each of length {C.shape[1]}.")
best_if_s1 = np.argmax(s1, axis=1)
best_if_s2 = np.argmax(s2, axis=1)
tot1 = np.sum(np.max(s1, axis=1))
tot2 = np.sum(np.max(s2, axis=1))
state = 1 if tot1 >= tot2 else 2
indices = best_if_s1 if state == 1 else best_if_s2
chars = [index_to_pair(i) for i in indices]
decoded = ''.join(a+b for a, b in chars)
return decoded, state
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
"""
Decode received samples by maximum correlation score (state unknown, per-block).
"""
n = C.shape[1]
half = n // 2
num = Y.size // n
C1 = C[:, :half]
C2 = C[:, half:]
sqrtG = np.sqrt(G)
recovered = []
for k in range(num):
Yb = Y[k * n:(k + 1) * n]
Ye, Yo = Yb[0::2], Yb[1::2]
s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
score = np.maximum(s1, s2)
best = int(np.argmax(score))
recovered.append(ALPHABET[best])
return "".join(recovered)
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> (str, int):
"""
1) Estimate the single channel state s ∈ {1, 2} by comparing total energy
on even vs odd positions across the entire Y.
2) Decode **all** blocks using that one state's scoring rule.
Returns (decoded_string, estimated_state).
"""
n = C.shape[1]
half = n // 2
num = Y.size // n
C1 = C[:, :half]
C2 = C[:, half:]
sqrtG = np.sqrt(G)
# 1) state estimate from full-length energies
Ye_all = Y[0::2]
Yo_all = Y[1::2]
E_even = np.sum(Ye_all ** 2)
E_odd = np.sum(Yo_all ** 2)
s_est = 1 if E_even > E_odd else 2
recovered = []
for k in range(num):
Yb = Y[k * n:(k + 1) * n]
Ye, Yo = Yb[0::2], Yb[1::2]
if s_est == 1:
score = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
else:
score = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
best = int(np.argmax(score))
recovered.append(ALPHABET[best])
return "".join(recovered), s_est
return decode_blocks_with_state(Y, C)[0]
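Putting the new decoder together with the new encoder below, the intended call pattern appears to be: build the codebook once, encode a 40-character message, pass it through the channel, and decode with state estimation. A minimal end-to-end sketch under that reading, using `channel.channel` as the local simulator that main.py imports (flow and names are taken from the code shown, nothing new):

```python
import numpy as np
import encoder, decoder, channel

msg = "Lorem ipsum dolor sit amet. consectetuer"        # exactly 40 characters
C, alpha = encoder.make_codebook()                      # new API: returns (C, alpha)
x = encoder.encode_message(msg, C)                      # 20 codewords, samples duplicated
y = channel.channel(x)                                  # local channel simulation
decoded, state = decoder.decode_blocks_with_state(y, C)
print(f"state={state} ok={decoded == msg}")
```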

decoder_backup.py Normal file

@@ -0,0 +1,103 @@
# decoder_backup.py
import numpy as np
from encoder_backup import ALPHABET, G
# Match the channel's noise variance
SIGMA2 = 10.0
def _block_ll_scores(Yb: np.ndarray,
C1: np.ndarray,
C2: np.ndarray,
sqrtG: float
) -> tuple[np.ndarray, np.ndarray]:
"""
Compute per-symbol log-likelihood scores for one interleaved block Yb
under the two channel states (even-boosted vs. odd-boosted).
Returns (scores_state1, scores_state2).
"""
# Split received block into even/odd samples
Ye, Yo = Yb[0::2], Yb[1::2]
# Precompute the squared-norm penalties for each codeword half
# (these come from the -||Y - H_s C||^2 term)
# state1: even half is √G * C1, odd half is C2
E1 = G * np.sum(C1**2, axis=1) + np.sum(C2**2, axis=1)
# state2: even half is C1, odd half is √G * C2
E2 = np.sum(C1**2, axis=1) + G * np.sum(C2**2, axis=1)
# Correlation terms
corr1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
corr2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
# ML log-likelihood (up to the codeword-independent constant -||Y||^2/(2σ²)):
scores1 = (corr1 - 0.5 * E1) / SIGMA2
scores2 = (corr2 - 0.5 * E2) / SIGMA2
return scores1, scores2
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
"""
Per-block ML decoding marginalizing over the unknown state:
for each block, compute scores1/2 via _block_ll_scores, then
marginal_score = logaddexp(scores1, scores2) and pick argmax.
"""
n = C.shape[1]
assert Y.size % n == 0, "Y length must be a multiple of codeword length"
num_blocks = Y.size // n
half = n // 2
C1, C2 = C[:, :half], C[:, half:]
sqrtG = np.sqrt(G)
recovered = []
for k in range(num_blocks):
Yb = Y[k*n:(k+1)*n]
s1, s2 = _block_ll_scores(Yb, C1, C2, sqrtG)
# marginal log-likelihood per symbol
marg = np.logaddexp(s1, s2)
best = int(np.argmax(marg))
recovered.append(ALPHABET[best])
return "".join(recovered)
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> (str, int):
"""
Joint ML state estimation and decoding:
- For each block, get per-state scores via _block_ll_scores
- Pick the best symbol index under each state; sum those log-likelihoods across blocks
- Choose the state with the higher total log-likelihood
- Reconstruct the string using the best-symbol indices for that state
Returns (decoded_string, estimated_state).
"""
n = C.shape[1]
assert Y.size % n == 0, "Y length must be a multiple of codeword length"
num_blocks = Y.size // n
half = n // 2
C1, C2 = C[:, :half], C[:, half:]
sqrtG = np.sqrt(G)
total1, total2 = 0.0, 0.0
best1, best2 = [], []
for k in range(num_blocks):
Yb = Y[k*n:(k+1)*n]
s1, s2 = _block_ll_scores(Yb, C1, C2, sqrtG)
idx1 = int(np.argmax(s1))
idx2 = int(np.argmax(s2))
total1 += s1[idx1]
total2 += s2[idx2]
best1.append(idx1)
best2.append(idx2)
s_est = 1 if total1 >= total2 else 2
chosen = best1 if s_est == 1 else best2
decoded = "".join(ALPHABET[i] for i in chosen)
return decoded, s_est
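The two decoders above differ only in how they treat the unknown state: `decode_blocks` marginalizes it per block with `np.logaddexp`, while `decode_blocks_with_state` commits to a single state for the whole message. A toy illustration of the marginalization step, with made-up per-state scores for three candidate codewords (not the real codebook):

```python
import numpy as np

# Hypothetical per-state log-likelihood scores for 3 candidate codewords in one block.
s1 = np.array([-4.0, -1.0, -9.0])   # scores assuming state 1 (even samples boosted)
s2 = np.array([-8.0, -2.5, -6.0])   # scores assuming state 2 (odd samples boosted)

marg = np.logaddexp(s1, s2)          # log(exp(s1) + exp(s2)), computed stably
print(marg)                          # ~[-3.98, -0.80, -5.95]
print(int(np.argmax(marg)))          # -> 1: codeword 1 wins after marginalizing the state
```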


@@ -1,56 +1,82 @@
# encoder.py
# encoder.py — 2-char/1-codeword implementation
import numpy as np
from typing import Tuple
# System parameters
G = 10.0 # power gain for even samples
ENERGY_LIMIT = 2000.0 # total energy per block
ALPHABET = (
[chr(i) for i in range(ord('a'), ord('z')+1)] +
[chr(i) for i in range(ord('A'), ord('Z')+1)] +
[str(i) for i in range(10)] +
[' ', '.']
##############################################################################
# Public constants
##############################################################################
CHAR_SET = (
[chr(i) for i in range(ord('a'), ord('z')+1)] +
[chr(i) for i in range(ord('A'), ord('Z')+1)] +
[str(i) for i in range(10)] +
[' ', '.']
)
assert len(ALPHABET) == 64, "Alphabet must be size 64"
assert len(CHAR_SET) == 64
CHAR_TO_IDX = {c: i for i, c in enumerate(CHAR_SET)}
IDX_TO_CHAR = {i: c for c, i in CHAR_TO_IDX.items()}
G = 10.0 # channel gain
ENERGY_LIMIT = 2000.0 # global limit on ‖x‖²
TEXT_LEN = 40 # must stay 40
def hadamard(r: int) -> np.ndarray:
##############################################################################
# Hadamard-codebook utilities
##############################################################################
def _hadamard(r: int) -> np.ndarray:
if r == 0:
return np.array([[1]], dtype=float)
M = hadamard(r-1)
return np.block([[M, M], [M, -M]])
return np.array([[1.]], dtype=np.float32)
H = _hadamard(r-1)
return np.block([[H, H],
[H, -H]])
def _Br(r: int) -> np.ndarray:
H = _hadamard(r)
return np.vstack([H, -H]) # 2^(r+1) × 2^r
def Br(r: int) -> np.ndarray:
M = hadamard(r)
return np.vstack([M, -M])
def make_codebook(r: int, num_blocks: int) -> np.ndarray:
##############################################################################
# Public API
##############################################################################
def make_codebook(r: int = 11,
num_blocks: int = TEXT_LEN//2,
energy_budget: float = ENERGY_LIMIT
) -> Tuple[np.ndarray, float]:
"""
Build the 64×64 codebook and scale each block so its energy stays within ENERGY_LIMIT/num_blocks
Builds the scaled codebook C (4096×4096) used by both encoder & decoder.
α is chosen so that **after the per-sample duplication** in encode_message,
a full 20-block message consumes exactly `energy_budget`.
"""
B = Br(r)
C = np.hstack([B, B]).astype(float)
raw_norm = np.sum(C[0]**2)
margin = 0.95
alpha = margin * (ENERGY_LIMIT / num_blocks) / raw_norm
return np.sqrt(alpha) * C
B = _Br(r) # 4096 × 2048
C = np.hstack((B, B)).astype(np.float32) # 4096 × 4096
n = C.shape[1] # 4096
dup_factor = 2 # sample-duplication
alpha = energy_budget / (num_blocks * dup_factor * n)
C *= np.sqrt(alpha, dtype=C.dtype)
return C, alpha
def pair_to_index(a: str, b: str) -> int:
return 64 * CHAR_TO_IDX[a] + CHAR_TO_IDX[b]
def interleave(c: np.ndarray) -> np.ndarray:
half = c.size // 2
x = np.empty(c.size)
x[0::2] = c[:half]
x[1::2] = c[half:]
return x
def index_to_pair(k: int) -> tuple[str, str]:
if not 0 <= k < 4096:
raise ValueError("index out of range [0,4095]")
return IDX_TO_CHAR[k // 64], IDX_TO_CHAR[k % 64]
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
"""
Encode 40-character message into interleaved code symbols
Encode a 40-character message. Each 2-character pair → one codeword row.
After concatenation the whole signal is duplicated sample-wise so that
the channel's even / odd indices each carry one clean copy.
"""
if len(msg) != 40:
if len(msg) != TEXT_LEN:
raise ValueError("Message must be exactly 40 characters.")
idx = [ALPHABET.index(ch) for ch in msg]
blocks = [interleave(C[i]) for i in idx]
return np.concatenate(blocks)
pairs = [(msg[i], msg[i+1]) for i in range(0, TEXT_LEN, 2)]
rows = [C[pair_to_index(a, b)] for a, b in pairs] # 20×4096
signal = np.concatenate(rows).astype(np.float32)
signal = np.repeat(signal, 2) # duplicate
# tight numeric safety-check (≡ 2000, barring float error)
e = np.sum(signal**2)
if not np.isclose(e, ENERGY_LIMIT, atol=1e-3):
raise RuntimeError(f"energy sanity check failed ({e:.3f} ≠ 2000)")
return signal
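To make the scaling concrete: with the defaults above (20 blocks of n = 4096 samples and a duplication factor of 2), α spreads the 2000-unit budget evenly over the 163,840 transmitted samples. A quick numeric check, assuming the new `make_codebook`/`encode_message` from this commit:

```python
import numpy as np
import encoder

C, alpha = encoder.make_codebook()      # defaults: r=11, 20 blocks
print(alpha)                            # 2000 / (20 * 2 * 4096) ≈ 0.0122

x = encoder.encode_message("Lorem ipsum dolor sit amet. consectetuer", C)
print(x.size)                           # 163840 samples after duplication
print(float(np.sum(x ** 2)))            # ≈ 2000.0, i.e. the full energy budget
```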

main.py

@@ -1,119 +1,65 @@
#!/usr/bin/env python3
"""
unchanged CLI; only *r* and the assert on message length moved to the encoder
"""
import argparse, sys, numpy as np, subprocess, pathlib, os, tempfile
import encoder, decoder, channel
import argparse
import sys
import numpy as np
import encoder
import decoder
import channel
import subprocess
import pathlib
import os
import tempfile
INPUT_FILE = OUTPUT_FILE = None
# Global paths for debugging
INPUT_FILE = None
OUTPUT_FILE = None
def transmit(msg, C): return encoder.encode_message(msg, C)
def receive_local(x): return channel.channel(x)
def transmit(msg, C):
return encoder.encode_message(msg, C)
def receive_local(c):
return channel.channel(c)
def receive_server(c, hostname, port):
def receive_server(x, host, port):
global INPUT_FILE, OUTPUT_FILE
# write input to the given file, or to a temp file
if INPUT_FILE:
in_name, delete_in = INPUT_FILE, False
np.savetxt(in_name, c)
in_f, rm_in = INPUT_FILE, False
np.savetxt(in_f, x)
else:
tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
np.savetxt(tf.name, c)
in_name, delete_in = tf.name, True
tf.close()
# write output to the given file, or to a temp file
tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
np.savetxt(tf.name, x); tf.close()
in_f, rm_in = tf.name, True
if OUTPUT_FILE:
out_name, delete_out = OUTPUT_FILE, False
out_f, rm_out = OUTPUT_FILE, False
else:
fd, out_name = tempfile.mkstemp(suffix='.txt')
os.close(fd)
delete_out = True
fd, out_f = tempfile.mkstemp(suffix='.txt'); os.close(fd); rm_out = True
# call client.py
cmd = [
sys.executable,
str(pathlib.Path(__file__).parent / 'client.py'),
'--input_file', in_name,
'--output_file', out_name,
'--srv_hostname', hostname,
'--srv_port', str(port)
]
cmd = [sys.executable, str(pathlib.Path(__file__).parent/'client.py'),
'--input_file', in_f, '--output_file', out_f,
'--srv_hostname', host, '--srv_port', str(port)]
try:
subprocess.run(cmd, check=True)
Y = np.loadtxt(out_name)
Y = np.loadtxt(out_f)
finally:
if delete_in and os.path.exists(in_name): os.remove(in_name)
if delete_out and os.path.exists(out_name): os.remove(out_name)
if rm_in and os.path.exists(in_f): os.remove(in_f)
if rm_out and os.path.exists(out_f): os.remove(out_f)
return Y
def receive(c, mode, hostname, port):
if mode == 'local':
return receive_local(c)
elif mode == 'server':
return receive_server(c, hostname, port)
else:
raise ValueError("mode must be 'local' or 'server'")
def receive(x, mode, host, port):
return receive_local(x) if mode=='local' else receive_server(x,host,port)
def test_performance(msg, num_trials, mode, hostname, port):
if len(msg) != 40:
raise ValueError('Message must be exactly 40 characters.')
def test(msg, n_trials, mode, host, port):
C, _ = encoder.make_codebook() # r=11 by default
print(f"Using codebook with {C.shape[0]} codewords, {C.shape[1]} symbols each.")
ok = 0
for _ in range(n_trials):
x = transmit(msg, C)
print(f"Transmitted {len(x):,} samples (energy={np.dot(x,x):.2f})")
y = receive(x, mode, host, port)
print(f"Received {len(y):,} samples (energy={np.dot(y,y):.2f})")
est, _ = decoder.decode_blocks_with_state(y, C)
if est == msg: ok += 1
print(f"{ok}/{n_trials} perfect decodes ({100*ok/n_trials:.2f}%)")
# build the same codebook
C = encoder.make_codebook(r=5, num_blocks=40)
successes = 0
print(f"Original message: {msg!r}")
print(f"Running {num_trials} trials over '{mode}' channel...\n")
for i in range(num_trials):
# TX → channel → RX
c = transmit(msg, C)
Y = receive(c, mode, hostname, port)
# decode with state estimation
est, s_est = decoder.decode_blocks_with_state(Y, C)
# count char errors
errors = sum(1 for a,b in zip(est, msg) if a!=b)
if errors == 0:
successes += 1
print(f"Trial {i+1:3d}: state={s_est} decoded={est!r} errors={errors}")
pct = 100 * successes / num_trials
print("\n=== Summary ===")
print(f" Total trials: {num_trials}")
print(f" Perfect decodings: {successes}")
print(f" Overall accuracy: {pct:.2f}%")
def parse_args():
p = argparse.ArgumentParser(description='Test comms system')
p.add_argument('--message','-m', required=True, help='40-char message')
p.add_argument('--trials','-n', type=int, default=200, help='Number of trials')
p.add_argument('--mode', choices=['local','server'], default='local')
p.add_argument('--hostname', default='iscsrv72.epfl.ch')
p.add_argument('--port', type=int, default=80)
p.add_argument('--input_file','-i', default=None,
help='(server mode) where to write input.txt')
p.add_argument('--output_file','-o',default=None,
help='(server mode) where to write output.txt')
def _args():
p=argparse.ArgumentParser()
p.add_argument('-m','--message',required=True,help='exactly 40 chars')
p.add_argument('-n','--trials', type=int, default=200)
p.add_argument('--mode',choices=['local','server'],default='local')
p.add_argument('--hostname',default='iscsrv72.epfl.ch'); p.add_argument('--port',type=int,default=80)
p.add_argument('-i','--input_file'); p.add_argument('-o','--output_file')
return p.parse_args()
if __name__=='__main__':
args = parse_args()
INPUT_FILE = args.input_file
OUTPUT_FILE = args.output_file
test_performance(args.message, args.trials,
args.mode, args.hostname, args.port)
a=_args(); INPUT_FILE=a.input_file; OUTPUT_FILE=a.output_file
test(a.message, a.trials, a.mode, a.hostname, a.port)
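channel.py itself is not part of this diff, but the decoders pin down what it has to do: pick one of two states at random, boost either the even- or the odd-indexed samples by √G, and add white Gaussian noise whose variance matches SIGMA2 = 10 in decoder_backup.py. A hypothetical stand-in consistent with those assumptions, useful only if the real channel module is unavailable for local testing:

```python
import numpy as np

G = 10.0        # power gain, matches encoder.G
SIGMA2 = 10.0   # noise variance, matches decoder_backup.SIGMA2

def channel_stub(x: np.ndarray, rng=None) -> np.ndarray:
    """Hypothetical stand-in for channel.channel: one randomly chosen state
    boosts the even or odd samples by sqrt(G), then AWGN is added."""
    rng = rng or np.random.default_rng()
    y = np.asarray(x, dtype=float).copy()
    if rng.integers(2) == 0:       # state 1: even-indexed samples boosted
        y[0::2] *= np.sqrt(G)
    else:                          # state 2: odd-indexed samples boosted
        y[1::2] *= np.sqrt(G)
    return y + rng.normal(0.0, np.sqrt(SIGMA2), size=y.size)
```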