feat: improve code
This commit is contained in:
parent
18aca25240
commit
fdb07e1c00
8 changed files with 159 additions and 417 deletions
18
README.md
18
README.md
|
@ -66,6 +66,24 @@ python3 main_backup.py \
|
||||||
|
|
||||||
> After running, `input.txt` contains your transmitted samples, and `output.txt` contains the noisy output from the server.
|
> After running, `input.txt` contains your transmitted samples, and `output.txt` contains the noisy output from the server.
|
||||||
|
|
||||||
|
### 4. Create input.txt
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 encoder.py "message_40_characters"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Create output.txt throught the channel
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 channel.py
|
||||||
|
```
|
||||||
|
### 6. Decode the output.tct
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 decoder.py
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Manual decoding of server output
|
## Manual decoding of server output
|
||||||
|
|
|
@ -18,3 +18,12 @@ def channel(x):
|
||||||
Y[::2] += x_even
|
Y[::2] += x_even
|
||||||
Y[1::2] += x_odd
|
Y[1::2] += x_odd
|
||||||
return Y
|
return Y
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Read input.txt
|
||||||
|
x = np.loadtxt("input.txt")
|
||||||
|
# Pass through channel
|
||||||
|
y = channel(x)
|
||||||
|
# Write output.txt
|
||||||
|
np.savetxt("output.txt", y)
|
||||||
|
print("Channel output written to output.txt")
|
56
decoder.py
56
decoder.py
|
@ -1,43 +1,49 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from encoder import G, ALPHABET, make_codebook
|
||||||
from encoder import (
|
|
||||||
G, CHAR_SET, pair_to_index, index_to_pair, make_codebook
|
|
||||||
)
|
|
||||||
print(np.__config__.show())
|
|
||||||
|
|
||||||
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray):
|
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray):
|
||||||
n_cw = C.shape[1] * 2 # 8192 samples after duplication
|
N = Y.size
|
||||||
nb = Y.size // n_cw
|
assert N % 2 == 0
|
||||||
assert nb * n_cw == Y.size, "length mismatch"
|
assert N <= 1_000_000
|
||||||
|
|
||||||
|
n_cw = C.shape[1] * 2
|
||||||
|
nb = Y.size // n_cw
|
||||||
|
assert nb * n_cw == Y.size, "length mismatch between Y and codebook"
|
||||||
|
|
||||||
|
Y_blocks = Y.reshape(nb, n_cw)
|
||||||
|
Y_even = Y_blocks[:, ::2].astype(np.float32)
|
||||||
|
Y_odd = Y_blocks[:, 1::2].astype(np.float32)
|
||||||
|
C = C.astype(np.float32)
|
||||||
|
|
||||||
sqrtg = np.sqrt(G)
|
sqrtg = np.sqrt(G)
|
||||||
|
|
||||||
# --- Vectorize even/odd block extraction done for speedup because nupy is dumb
|
|
||||||
Y_blocks = Y.reshape(nb, n_cw)
|
|
||||||
Y_even = Y_blocks[:, ::2] # shape (nb, 4096)
|
|
||||||
Y_odd = Y_blocks[:, 1::2] # shape (nb, 4096)
|
|
||||||
#print(f"Extracted {nb} blocks of {n_cw} samples each, ")
|
|
||||||
|
|
||||||
# --- Vectorize all block scoring
|
|
||||||
# Each: (nb, 4096) = (nb, 4096) @ (4096, 4096).T
|
|
||||||
#print(Y_blocks.dtype, C.dtype, sqrtg.dtype, Y_even.dtype, Y_odd.dtype)
|
|
||||||
|
|
||||||
Y_even = Y_even.astype(np.float32, copy=False)
|
|
||||||
Y_odd = Y_odd.astype(np.float32, copy=False)
|
|
||||||
C = C.astype(np.float32, copy=False)
|
|
||||||
s1 = sqrtg * (Y_even @ C.T) + (Y_odd @ C.T)
|
s1 = sqrtg * (Y_even @ C.T) + (Y_odd @ C.T)
|
||||||
s2 = (Y_even @ C.T) + sqrtg * (Y_odd @ C.T)
|
s2 = (Y_even @ C.T) + sqrtg * (Y_odd @ C.T)
|
||||||
#print(f"Scoring {nb} blocks with {C.shape[0]} codewords, each of length {C.shape[1]}.")
|
|
||||||
best_if_s1 = np.argmax(s1, axis=1)
|
best_if_s1 = np.argmax(s1, axis=1)
|
||||||
best_if_s2 = np.argmax(s2, axis=1)
|
best_if_s2 = np.argmax(s2, axis=1)
|
||||||
|
|
||||||
tot1 = np.sum(np.max(s1, axis=1))
|
tot1 = np.sum(np.max(s1, axis=1))
|
||||||
tot2 = np.sum(np.max(s2, axis=1))
|
tot2 = np.sum(np.max(s2, axis=1))
|
||||||
|
|
||||||
state = 1 if tot1 >= tot2 else 2
|
state = 1 if tot1 >= tot2 else 2
|
||||||
indices = best_if_s1 if state == 1 else best_if_s2
|
indices = best_if_s1 if state == 1 else best_if_s2
|
||||||
chars = [index_to_pair(i) for i in indices]
|
|
||||||
decoded = ''.join(a+b for a, b in chars)
|
msg_chars = []
|
||||||
|
for idx in indices:
|
||||||
|
first = ALPHABET[idx >> 6]
|
||||||
|
second = ALPHABET[idx & 0x3F]
|
||||||
|
msg_chars.append(first)
|
||||||
|
msg_chars.append(second)
|
||||||
|
|
||||||
|
decoded = ''.join(msg_chars)
|
||||||
return decoded, state
|
return decoded, state
|
||||||
|
|
||||||
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
||||||
return decode_blocks_with_state(Y, C)[0]
|
return decode_blocks_with_state(Y, C)[0]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
Y = np.loadtxt("output.txt")
|
||||||
|
C, _ = make_codebook()
|
||||||
|
decoded, state = decode_blocks_with_state(Y, C)
|
||||||
|
print(f"Decoded message: {decoded}")
|
||||||
|
print(f"Detected channel state: {state}")
|
|
@ -1,103 +0,0 @@
|
||||||
# decoder_backup.py
|
|
||||||
import numpy as np
|
|
||||||
from encoder_backup import ALPHABET, G
|
|
||||||
|
|
||||||
# Match the channel’s noise variance
|
|
||||||
SIGMA2 = 10.0
|
|
||||||
|
|
||||||
def _block_ll_scores(Yb: np.ndarray,
|
|
||||||
C1: np.ndarray,
|
|
||||||
C2: np.ndarray,
|
|
||||||
sqrtG: float
|
|
||||||
) -> tuple[np.ndarray, np.ndarray]:
|
|
||||||
"""
|
|
||||||
Compute per-symbol log-likelihood scores for one interleaved block Yb
|
|
||||||
under the two channel states (even-boosted vs. odd-boosted).
|
|
||||||
Returns (scores_state1, scores_state2).
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Split received block into even/odd samples
|
|
||||||
Ye, Yo = Yb[0::2], Yb[1::2]
|
|
||||||
|
|
||||||
# Precompute the squared-norm penalties for each codeword half
|
|
||||||
# (these come from the -||Y - H_s C||^2 term)
|
|
||||||
# state1: even half is √G * C1, odd half is C2
|
|
||||||
E1 = G * np.sum(C1**2, axis=1) + np.sum(C2**2, axis=1)
|
|
||||||
# state2: even half is C1, odd half is √G * C2
|
|
||||||
E2 = np.sum(C1**2, axis=1) + G * np.sum(C2**2, axis=1)
|
|
||||||
|
|
||||||
# Correlation terms
|
|
||||||
corr1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
|
|
||||||
corr2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
|
|
||||||
|
|
||||||
# ML log‐likelihood (up to constant −1/(2σ²)):
|
|
||||||
scores1 = (corr1 - 0.5 * E1) / SIGMA2
|
|
||||||
scores2 = (corr2 - 0.5 * E2) / SIGMA2
|
|
||||||
|
|
||||||
return scores1, scores2
|
|
||||||
|
|
||||||
|
|
||||||
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
|
||||||
"""
|
|
||||||
Per-block ML decoding marginalizing over the unknown state:
|
|
||||||
for each block, compute scores1/2 via _block_ll_scores, then
|
|
||||||
marginal_score = logaddexp(scores1, scores2) and pick argmax.
|
|
||||||
"""
|
|
||||||
n = C.shape[1]
|
|
||||||
assert Y.size % n == 0, "Y length must be a multiple of codeword length"
|
|
||||||
num_blocks = Y.size // n
|
|
||||||
|
|
||||||
half = n // 2
|
|
||||||
C1, C2 = C[:, :half], C[:, half:]
|
|
||||||
sqrtG = np.sqrt(G)
|
|
||||||
|
|
||||||
recovered = []
|
|
||||||
for k in range(num_blocks):
|
|
||||||
Yb = Y[k*n:(k+1)*n]
|
|
||||||
s1, s2 = _block_ll_scores(Yb, C1, C2, sqrtG)
|
|
||||||
# marginal log-likelihood per symbol
|
|
||||||
marg = np.logaddexp(s1, s2)
|
|
||||||
best = int(np.argmax(marg))
|
|
||||||
recovered.append(ALPHABET[best])
|
|
||||||
|
|
||||||
return "".join(recovered)
|
|
||||||
|
|
||||||
|
|
||||||
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> (str, int):
|
|
||||||
"""
|
|
||||||
Joint‐ML state estimation and decoding:
|
|
||||||
- For each block, get per-state scores via _block_ll_scores
|
|
||||||
- Pick the best symbol index under each state; sum those log‐L’s across blocks
|
|
||||||
- Choose the state with the higher total log‐L
|
|
||||||
- Reconstruct the string using the best‐symbol indices for that state
|
|
||||||
Returns (decoded_string, estimated_state).
|
|
||||||
"""
|
|
||||||
n = C.shape[1]
|
|
||||||
assert Y.size % n == 0, "Y length must be a multiple of codeword length"
|
|
||||||
num_blocks = Y.size // n
|
|
||||||
|
|
||||||
half = n // 2
|
|
||||||
C1, C2 = C[:, :half], C[:, half:]
|
|
||||||
sqrtG = np.sqrt(G)
|
|
||||||
|
|
||||||
total1, total2 = 0.0, 0.0
|
|
||||||
best1, best2 = [], []
|
|
||||||
|
|
||||||
for k in range(num_blocks):
|
|
||||||
Yb = Y[k*n:(k+1)*n]
|
|
||||||
s1, s2 = _block_ll_scores(Yb, C1, C2, sqrtG)
|
|
||||||
|
|
||||||
idx1 = int(np.argmax(s1))
|
|
||||||
idx2 = int(np.argmax(s2))
|
|
||||||
|
|
||||||
total1 += s1[idx1]
|
|
||||||
total2 += s2[idx2]
|
|
||||||
|
|
||||||
best1.append(idx1)
|
|
||||||
best2.append(idx2)
|
|
||||||
|
|
||||||
s_est = 1 if total1 >= total2 else 2
|
|
||||||
chosen = best1 if s_est == 1 else best2
|
|
||||||
decoded = "".join(ALPHABET[i] for i in chosen)
|
|
||||||
|
|
||||||
return decoded, s_est
|
|
117
encoder.py
117
encoder.py
|
@ -1,82 +1,65 @@
|
||||||
# encoder.py — 2-char/1-codeword implementation
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Tuple
|
import sys
|
||||||
|
|
||||||
##############################################################################
|
ALPHABET = "abcdefghijklmnopqrstuvwxyz" \
|
||||||
# Public constants
|
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" \
|
||||||
##############################################################################
|
"0123456789 ."
|
||||||
CHAR_SET = (
|
|
||||||
[chr(i) for i in range(ord('a'), ord('z')+1)] +
|
|
||||||
[chr(i) for i in range(ord('A'), ord('Z')+1)] +
|
|
||||||
[str(i) for i in range(10)] +
|
|
||||||
[' ', '.']
|
|
||||||
)
|
|
||||||
assert len(CHAR_SET) == 64
|
|
||||||
CHAR_TO_IDX = {c: i for i, c in enumerate(CHAR_SET)}
|
|
||||||
IDX_TO_CHAR = {i: c for c, i in CHAR_TO_IDX.items()}
|
|
||||||
|
|
||||||
G = 10.0 # channel gain
|
CHAR_TO_IDX = {c: i for i, c in enumerate(ALPHABET)}
|
||||||
ENERGY_LIMIT = 2000.0 # global limit ‖x‖²
|
|
||||||
TEXT_LEN = 40 # must stay 40
|
|
||||||
|
|
||||||
##############################################################################
|
G = 10.0
|
||||||
# Hadamard-codebook utilities
|
ENERGY_LIMIT = 2000.0
|
||||||
##############################################################################
|
TEXT_LENGTH = 40
|
||||||
def _hadamard(r: int) -> np.ndarray:
|
ALPHABET_LENGTH = len(ALPHABET)
|
||||||
|
|
||||||
|
assert ALPHABET_LENGTH == 64
|
||||||
|
|
||||||
|
def pair_to_index(a: str, b: str) -> int:
|
||||||
|
i1 = CHAR_TO_IDX[a]
|
||||||
|
i2 = CHAR_TO_IDX[b]
|
||||||
|
return (i1 << 6) + i2
|
||||||
|
|
||||||
|
def hadamard(r: int) -> np.ndarray:
|
||||||
if r == 0:
|
if r == 0:
|
||||||
return np.array([[1.]], dtype=np.float32)
|
return np.array([[1.]], dtype=np.float32)
|
||||||
H = _hadamard(r-1)
|
H = hadamard(r - 1)
|
||||||
return np.block([[H, H],
|
return np.block([[H, H], [H, -H]])
|
||||||
[H, -H]])
|
|
||||||
|
|
||||||
def _Br(r: int) -> np.ndarray:
|
def Br(r: int) -> np.ndarray:
|
||||||
H = _hadamard(r)
|
H = hadamard(r)
|
||||||
return np.vstack([H, -H]) # 2^(r+1) × 2^r
|
return np.vstack([H, -H])
|
||||||
|
|
||||||
##############################################################################
|
def make_codebook(r: int = 11, num_blocks: int = TEXT_LENGTH // 2, Eb: float = ENERGY_LIMIT):
|
||||||
# Public API
|
B = Br(r)
|
||||||
##############################################################################
|
C = np.hstack((B, B)).astype(np.float32)
|
||||||
def make_codebook(r: int = 11,
|
n = C.shape[1]
|
||||||
num_blocks: int = TEXT_LEN//2,
|
alpha = Eb / (num_blocks * 2 * n)
|
||||||
energy_budget: float = ENERGY_LIMIT
|
|
||||||
) -> Tuple[np.ndarray, float]:
|
|
||||||
"""
|
|
||||||
Builds the scaled codebook C (4096×4096) used by both encoder & decoder.
|
|
||||||
|
|
||||||
α is chosen so that **after the per-sample duplication** in encode_message,
|
|
||||||
a full 20-block message consumes exactly `energy_budget`.
|
|
||||||
"""
|
|
||||||
B = _Br(r) # 4096 × 2048
|
|
||||||
C = np.hstack((B, B)).astype(np.float32) # 4096 × 4096
|
|
||||||
n = C.shape[1] # 4096
|
|
||||||
dup_factor = 2 # sample-duplication
|
|
||||||
alpha = energy_budget / (num_blocks * dup_factor * n)
|
|
||||||
C *= np.sqrt(alpha, dtype=C.dtype)
|
C *= np.sqrt(alpha, dtype=C.dtype)
|
||||||
return C, alpha
|
return C, alpha
|
||||||
|
|
||||||
def pair_to_index(a: str, b: str) -> int:
|
|
||||||
return 64 * CHAR_TO_IDX[a] + CHAR_TO_IDX[b]
|
|
||||||
|
|
||||||
def index_to_pair(k: int) -> tuple[str, str]:
|
|
||||||
if not 0 <= k < 4096:
|
|
||||||
raise ValueError("index out of range [0,4095]")
|
|
||||||
return IDX_TO_CHAR[k // 64], IDX_TO_CHAR[k % 64]
|
|
||||||
|
|
||||||
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
|
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
|
||||||
"""
|
if len(msg) != TEXT_LENGTH:
|
||||||
Encode a 40-character message. Each 2-character pair → one codeword row.
|
|
||||||
After concatenation the whole signal is duplicated sample-wise so that
|
|
||||||
the channel’s even / odd indices each carry one clean copy.
|
|
||||||
"""
|
|
||||||
if len(msg) != TEXT_LEN:
|
|
||||||
raise ValueError("Message must be exactly 40 characters.")
|
raise ValueError("Message must be exactly 40 characters.")
|
||||||
|
|
||||||
pairs = [(msg[i], msg[i+1]) for i in range(0, TEXT_LEN, 2)]
|
idxs = [CHAR_TO_IDX[c] for c in msg]
|
||||||
rows = [C[pair_to_index(a, b)] for a, b in pairs] # 20×4096
|
pair_idxs = [(idxs[i] << 6) | idxs[i+1] for i in range(0, TEXT_LENGTH, 2)]
|
||||||
signal = np.concatenate(rows).astype(np.float32)
|
|
||||||
signal = np.repeat(signal, 2) # duplicate
|
signal = np.repeat(C[pair_idxs].ravel(), 2).astype(np.float32)
|
||||||
# tight numeric safety-check (≡ 2000, barring float error)
|
|
||||||
e = np.sum(signal**2)
|
if not np.isclose(signal.dot(signal), ENERGY_LIMIT, atol=1e-3):
|
||||||
if not np.isclose(e, ENERGY_LIMIT, atol=1e-3):
|
raise RuntimeError("energy check failed")
|
||||||
raise RuntimeError(f"energy sanity check failed ({e:.3f} ≠ 2000)")
|
|
||||||
return signal
|
return signal
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
if len(sys.argv) > 1:
|
||||||
|
msg = sys.argv[1]
|
||||||
|
else:
|
||||||
|
msg = input(f"Enter a message ({TEXT_LENGTH} characters): ").strip()
|
||||||
|
if len(msg) != TEXT_LENGTH:
|
||||||
|
print(f"Message must be exactly {TEXT_LENGTH} characters.")
|
||||||
|
sys.exit(1)
|
||||||
|
C, _ = make_codebook()
|
||||||
|
signal = encode_message(msg, C)
|
||||||
|
np.savetxt("input.txt", signal)
|
||||||
|
print("Signal written to input.txt")
|
|
@ -1,57 +0,0 @@
|
||||||
# encoder_backup.py
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# System parameters
|
|
||||||
G = 10.0 # power gain for even samples
|
|
||||||
ENERGY_LIMIT = 2000.0 # total energy per block
|
|
||||||
ALPHABET = (
|
|
||||||
[chr(i) for i in range(ord('a'), ord('z')+1)] +
|
|
||||||
[chr(i) for i in range(ord('A'), ord('Z')+1)] +
|
|
||||||
[str(i) for i in range(10)] +
|
|
||||||
[' ', '.']
|
|
||||||
)
|
|
||||||
assert len(ALPHABET) == 64, "Alphabet must be size 64"
|
|
||||||
|
|
||||||
|
|
||||||
def hadamard(r: int) -> np.ndarray:
|
|
||||||
if r == 0:
|
|
||||||
return np.array([[1]], dtype=float)
|
|
||||||
M = hadamard(r-1)
|
|
||||||
return np.block([[M, M], [M, -M]])
|
|
||||||
|
|
||||||
|
|
||||||
def Br(r: int) -> np.ndarray:
|
|
||||||
M = hadamard(r)
|
|
||||||
return np.vstack([M, -M])
|
|
||||||
|
|
||||||
|
|
||||||
def make_codebook(r: int, num_blocks: int) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Build 64x64 codebook and scale blocks so energy per block ≤ ENERGY_LIMIT/num_blocks
|
|
||||||
"""
|
|
||||||
B_full = Br(r) #
|
|
||||||
B = B_full[: len(ALPHABET), :] # now shape (64, 2^r)
|
|
||||||
C = np.hstack([B, B]).astype(float) # shape (64, 2^{r+1})
|
|
||||||
raw_norm = np.sum(C[0]**2)
|
|
||||||
margin = 0.99
|
|
||||||
alpha = margin * (ENERGY_LIMIT / num_blocks) / raw_norm
|
|
||||||
return np.sqrt(alpha) * C
|
|
||||||
|
|
||||||
|
|
||||||
def interleave(c: np.ndarray) -> np.ndarray:
|
|
||||||
half = c.size // 2
|
|
||||||
x = np.empty(c.size)
|
|
||||||
x[0::2] = c[:half]
|
|
||||||
x[1::2] = c[half:]
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Encode 40-character message into interleaved code symbols
|
|
||||||
"""
|
|
||||||
if len(msg) != 40:
|
|
||||||
raise ValueError("Message must be exactly 40 characters.")
|
|
||||||
idx = [ALPHABET.index(ch) for ch in msg]
|
|
||||||
blocks = [interleave(C[i]) for i in idx]
|
|
||||||
return np.concatenate(blocks)
|
|
97
main.py
97
main.py
|
@ -1,65 +1,70 @@
|
||||||
#!/usr/bin/env python3
|
import argparse
|
||||||
"""
|
import numpy as np
|
||||||
unchanged CLI – only the *r* and the assert on message length moved to encoder
|
|
||||||
"""
|
|
||||||
import argparse, sys, numpy as np, subprocess, pathlib, os, tempfile
|
|
||||||
import encoder, decoder, channel
|
import encoder, decoder, channel
|
||||||
|
import subprocess, pathlib, os, tempfile, sys
|
||||||
|
|
||||||
INPUT_FILE = OUTPUT_FILE = None
|
|
||||||
|
|
||||||
def transmit(msg, C): return encoder.encode_message(msg, C)
|
def transmit(msg, C):
|
||||||
def receive_local(x): return channel.channel(x)
|
return encoder.encode_message(msg, C)
|
||||||
|
|
||||||
def receive_server(x, host, port):
|
def receive_local(x):
|
||||||
global INPUT_FILE, OUTPUT_FILE
|
return channel.channel(x)
|
||||||
if INPUT_FILE:
|
|
||||||
in_f, rm_in = INPUT_FILE, False
|
|
||||||
np.savetxt(in_f, x)
|
|
||||||
else:
|
|
||||||
tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
|
|
||||||
np.savetxt(tf.name, x); tf.close()
|
|
||||||
in_f, rm_in = tf.name, True
|
|
||||||
if OUTPUT_FILE:
|
|
||||||
out_f, rm_out = OUTPUT_FILE, False
|
|
||||||
else:
|
|
||||||
fd, out_f = tempfile.mkstemp(suffix='.txt'); os.close(fd); rm_out = True
|
|
||||||
|
|
||||||
cmd = [sys.executable, str(pathlib.Path(__file__).parent/'client.py'),
|
def receive_server(x, host, port, input_file=None, output_file=None):
|
||||||
'--input_file', in_f, '--output_file', out_f,
|
in_f = input_file or tempfile.NamedTemporaryFile(suffix='.txt', delete=False).name
|
||||||
'--srv_hostname', host, '--srv_port', str(port)]
|
np.savetxt(in_f, x)
|
||||||
|
|
||||||
|
out_f = output_file or tempfile.mkstemp(suffix='.txt')[1]
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
sys.executable, str(pathlib.Path(__file__).parent / 'client.py'),
|
||||||
|
'--input_file', in_f, '--output_file', out_f,
|
||||||
|
'--srv_hostname', host, '--srv_port', str(port)
|
||||||
|
]
|
||||||
try:
|
try:
|
||||||
subprocess.run(cmd, check=True)
|
subprocess.run(cmd, check=True)
|
||||||
Y = np.loadtxt(out_f)
|
Y = np.loadtxt(out_f)
|
||||||
finally:
|
finally:
|
||||||
if rm_in and os.path.exists(in_f): os.remove(in_f)
|
if not input_file and os.path.exists(in_f):
|
||||||
if rm_out and os.path.exists(out_f): os.remove(out_f)
|
os.remove(in_f)
|
||||||
|
if not output_file and os.path.exists(out_f):
|
||||||
|
os.remove(out_f)
|
||||||
return Y
|
return Y
|
||||||
|
|
||||||
def receive(x, mode, host, port):
|
|
||||||
return receive_local(x) if mode=='local' else receive_server(x,host,port)
|
|
||||||
|
|
||||||
def test(msg, n_trials, mode, host, port):
|
def receive(x, mode, host, port, input_file=None, output_file=None):
|
||||||
C, _ = encoder.make_codebook() # r=11 by default
|
if mode == 'local':
|
||||||
|
return receive_local(x)
|
||||||
|
return receive_server(x, host, port, input_file, output_file)
|
||||||
|
|
||||||
|
|
||||||
|
def test(msg, n_trials, mode, host, port, input_file=None, output_file=None):
|
||||||
|
C, _ = encoder.make_codebook()
|
||||||
print(f"Using codebook with {C.shape[0]} codewords, {C.shape[1]} symbols each.")
|
print(f"Using codebook with {C.shape[0]} codewords, {C.shape[1]} symbols each.")
|
||||||
ok = 0
|
ok = 0
|
||||||
for _ in range(n_trials):
|
for _ in range(n_trials):
|
||||||
x = transmit(msg, C)
|
x = transmit(msg, C)
|
||||||
print(f"Transmitted {len(x):,} samples (energy={np.dot(x,x):.2f})")
|
print(f"Transmitted {len(x):,} samples (energy={np.dot(x, x):.2f})")
|
||||||
y = receive(x, mode, host, port)
|
y = receive(x, mode, host, port, input_file, output_file)
|
||||||
print(f"Received {len(y):,} samples (energy={np.dot(y,y):.2f})")
|
print(f"Received {len(y):,} samples (energy={np.dot(y, y):.2f})")
|
||||||
est, _ = decoder.decode_blocks_with_state(y, C)
|
est, _ = decoder.decode_blocks_with_state(y, C)
|
||||||
if est == msg: ok += 1
|
if est == msg:
|
||||||
print(f"{ok}/{n_trials} perfect decodes ({100*ok/n_trials:.2f}%)")
|
ok += 1
|
||||||
|
print(f"{ok}/{n_trials} perfect decodes ({100 * ok / n_trials:.2f}%)")
|
||||||
|
|
||||||
def _args():
|
|
||||||
p=argparse.ArgumentParser()
|
def parse_args():
|
||||||
p.add_argument('-m','--message',required=True,help='exactly 40 chars')
|
p = argparse.ArgumentParser()
|
||||||
p.add_argument('-n','--trials', type=int, default=200)
|
p.add_argument('-m', '--message', required=True, help='exactly 40 chars')
|
||||||
p.add_argument('--mode',choices=['local','server'],default='local')
|
p.add_argument('-n', '--trials', type=int, default=200)
|
||||||
p.add_argument('--hostname',default='iscsrv72.epfl.ch'); p.add_argument('--port',type=int,default=80)
|
p.add_argument('--mode', choices=['local', 'server'], default='local')
|
||||||
p.add_argument('-i','--input_file'); p.add_argument('-o','--output_file')
|
p.add_argument('--hostname', default='iscsrv72.epfl.ch')
|
||||||
|
p.add_argument('--port', type=int, default=80)
|
||||||
|
p.add_argument('-i', '--input_file')
|
||||||
|
p.add_argument('-o', '--output_file')
|
||||||
return p.parse_args()
|
return p.parse_args()
|
||||||
|
|
||||||
if __name__=='__main__':
|
|
||||||
a=_args(); INPUT_FILE=a.input_file; OUTPUT_FILE=a.output_file
|
if __name__ == '__main__':
|
||||||
test(a.message, a.trials, a.mode, a.hostname, a.port)
|
a = parse_args()
|
||||||
|
test(a.message, a.trials, a.mode, a.hostname, a.port, a.input_file, a.output_file)
|
||||||
|
|
119
main_backup.py
119
main_backup.py
|
@ -1,119 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import sys
|
|
||||||
import numpy as np
|
|
||||||
import encoder
|
|
||||||
import decoder
|
|
||||||
import channel
|
|
||||||
import subprocess
|
|
||||||
import pathlib
|
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
|
|
||||||
# Global paths for debugging
|
|
||||||
INPUT_FILE = None
|
|
||||||
OUTPUT_FILE = None
|
|
||||||
|
|
||||||
def transmit(msg, C):
|
|
||||||
return encoder.encode_message(msg, C)
|
|
||||||
|
|
||||||
def receive_local(c):
|
|
||||||
return channel.channel(c)
|
|
||||||
|
|
||||||
def receive_server(c, hostname, port):
|
|
||||||
global INPUT_FILE, OUTPUT_FILE
|
|
||||||
# write or temp-file for input
|
|
||||||
if INPUT_FILE:
|
|
||||||
in_name, delete_in = INPUT_FILE, False
|
|
||||||
np.savetxt(in_name, c)
|
|
||||||
else:
|
|
||||||
tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
|
|
||||||
np.savetxt(tf.name, c)
|
|
||||||
in_name, delete_in = tf.name, True
|
|
||||||
tf.close()
|
|
||||||
# write or temp-file for output
|
|
||||||
if OUTPUT_FILE:
|
|
||||||
out_name, delete_out = OUTPUT_FILE, False
|
|
||||||
else:
|
|
||||||
fd, out_name = tempfile.mkstemp(suffix='.txt')
|
|
||||||
os.close(fd)
|
|
||||||
delete_out = True
|
|
||||||
|
|
||||||
# call client.py
|
|
||||||
cmd = [
|
|
||||||
sys.executable,
|
|
||||||
str(pathlib.Path(__file__).parent / 'client.py'),
|
|
||||||
'--input_file', in_name,
|
|
||||||
'--output_file', out_name,
|
|
||||||
'--srv_hostname', hostname,
|
|
||||||
'--srv_port', str(port)
|
|
||||||
]
|
|
||||||
try:
|
|
||||||
subprocess.run(cmd, check=True)
|
|
||||||
Y = np.loadtxt(out_name)
|
|
||||||
finally:
|
|
||||||
if delete_in and os.path.exists(in_name): os.remove(in_name)
|
|
||||||
if delete_out and os.path.exists(out_name): os.remove(out_name)
|
|
||||||
|
|
||||||
return Y
|
|
||||||
|
|
||||||
def receive(c, mode, hostname, port):
|
|
||||||
if mode == 'local':
|
|
||||||
return receive_local(c)
|
|
||||||
elif mode == 'server':
|
|
||||||
return receive_server(c, hostname, port)
|
|
||||||
else:
|
|
||||||
raise ValueError("mode must be 'local' or 'server'")
|
|
||||||
|
|
||||||
def test_performance(msg, num_trials, mode, hostname, port):
|
|
||||||
if len(msg) != 40:
|
|
||||||
raise ValueError('Message must be exactly 40 characters.')
|
|
||||||
|
|
||||||
# build the same codebook
|
|
||||||
C = encoder.make_codebook(r=11, num_blocks=40)
|
|
||||||
|
|
||||||
successes = 0
|
|
||||||
print(f"Original message: {msg!r}")
|
|
||||||
print(f"Running {num_trials} trials over '{mode}' channel...\n")
|
|
||||||
|
|
||||||
for i in range(num_trials):
|
|
||||||
# TX → channel → RX
|
|
||||||
c = transmit(msg, C)
|
|
||||||
Y = receive(c, mode, hostname, port)
|
|
||||||
|
|
||||||
# decode with state‐estimation
|
|
||||||
est, s_est = decoder.decode_blocks_with_state(Y, C)
|
|
||||||
|
|
||||||
# count char errors
|
|
||||||
errors = sum(1 for a,b in zip(est, msg) if a!=b)
|
|
||||||
if errors == 0:
|
|
||||||
successes += 1
|
|
||||||
|
|
||||||
print(f"Trial {i+1:3d}: state={s_est} decoded={est!r} errors={errors}")
|
|
||||||
|
|
||||||
pct = 100 * successes / num_trials
|
|
||||||
print("\n=== Summary ===")
|
|
||||||
print(f" Total trials: {num_trials}")
|
|
||||||
print(f" Perfect decodings: {successes}")
|
|
||||||
print(f" Overall accuracy: {pct:.2f}%")
|
|
||||||
|
|
||||||
def parse_args():
|
|
||||||
p = argparse.ArgumentParser(description='Test comms system')
|
|
||||||
p.add_argument('--message','-m', required=True, help='40-char message')
|
|
||||||
p.add_argument('--trials','-n', type=int, default=200, help='Number of trials')
|
|
||||||
p.add_argument('--mode', choices=['local','server'], default='local')
|
|
||||||
p.add_argument('--hostname', default='iscsrv72.epfl.ch')
|
|
||||||
p.add_argument('--port', type=int, default=80)
|
|
||||||
p.add_argument('--input_file','-i', default=None,
|
|
||||||
help='(server mode) where to write input.txt')
|
|
||||||
p.add_argument('--output_file','-o',default=None,
|
|
||||||
help='(server mode) where to write output.txt')
|
|
||||||
return p.parse_args()
|
|
||||||
|
|
||||||
if __name__=='__main__':
|
|
||||||
args = parse_args()
|
|
||||||
INPUT_FILE = args.input_file
|
|
||||||
OUTPUT_FILE = args.output_file
|
|
||||||
test_performance(args.message, args.trials,
|
|
||||||
args.mode, args.hostname, args.port)
|
|
Loading…
Add table
Reference in a new issue