feat: improve code
This commit is contained in:
parent
18aca25240
commit
fdb07e1c00
8 changed files with 159 additions and 417 deletions
18
README.md
18
README.md
|
@ -66,6 +66,24 @@ python3 main_backup.py \
|
|||
|
||||
> After running, `input.txt` contains your transmitted samples, and `output.txt` contains the noisy output from the server.
|
||||
|
||||
### 4. Create input.txt
|
||||
|
||||
```bash
|
||||
python3 encoder.py "message_40_characters"
|
||||
```
|
||||
|
||||
### 5. Create output.txt throught the channel
|
||||
|
||||
```bash
|
||||
python3 channel.py
|
||||
```
|
||||
### 6. Decode the output.tct
|
||||
|
||||
```bash
|
||||
python3 decoder.py
|
||||
```
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Manual decoding of server output
|
||||
|
|
|
@ -18,3 +18,12 @@ def channel(x):
|
|||
Y[::2] += x_even
|
||||
Y[1::2] += x_odd
|
||||
return Y
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Read input.txt
|
||||
x = np.loadtxt("input.txt")
|
||||
# Pass through channel
|
||||
y = channel(x)
|
||||
# Write output.txt
|
||||
np.savetxt("output.txt", y)
|
||||
print("Channel output written to output.txt")
|
56
decoder.py
56
decoder.py
|
@ -1,43 +1,49 @@
|
|||
import numpy as np
|
||||
|
||||
from encoder import (
|
||||
G, CHAR_SET, pair_to_index, index_to_pair, make_codebook
|
||||
)
|
||||
print(np.__config__.show())
|
||||
from encoder import G, ALPHABET, make_codebook
|
||||
|
||||
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray):
|
||||
n_cw = C.shape[1] * 2 # 8192 samples after duplication
|
||||
nb = Y.size // n_cw
|
||||
assert nb * n_cw == Y.size, "length mismatch"
|
||||
N = Y.size
|
||||
assert N % 2 == 0
|
||||
assert N <= 1_000_000
|
||||
|
||||
n_cw = C.shape[1] * 2
|
||||
nb = Y.size // n_cw
|
||||
assert nb * n_cw == Y.size, "length mismatch between Y and codebook"
|
||||
|
||||
Y_blocks = Y.reshape(nb, n_cw)
|
||||
Y_even = Y_blocks[:, ::2].astype(np.float32)
|
||||
Y_odd = Y_blocks[:, 1::2].astype(np.float32)
|
||||
C = C.astype(np.float32)
|
||||
|
||||
sqrtg = np.sqrt(G)
|
||||
|
||||
# --- Vectorize even/odd block extraction done for speedup because nupy is dumb
|
||||
Y_blocks = Y.reshape(nb, n_cw)
|
||||
Y_even = Y_blocks[:, ::2] # shape (nb, 4096)
|
||||
Y_odd = Y_blocks[:, 1::2] # shape (nb, 4096)
|
||||
#print(f"Extracted {nb} blocks of {n_cw} samples each, ")
|
||||
|
||||
# --- Vectorize all block scoring
|
||||
# Each: (nb, 4096) = (nb, 4096) @ (4096, 4096).T
|
||||
#print(Y_blocks.dtype, C.dtype, sqrtg.dtype, Y_even.dtype, Y_odd.dtype)
|
||||
|
||||
Y_even = Y_even.astype(np.float32, copy=False)
|
||||
Y_odd = Y_odd.astype(np.float32, copy=False)
|
||||
C = C.astype(np.float32, copy=False)
|
||||
s1 = sqrtg * (Y_even @ C.T) + (Y_odd @ C.T)
|
||||
s2 = (Y_even @ C.T) + sqrtg * (Y_odd @ C.T)
|
||||
#print(f"Scoring {nb} blocks with {C.shape[0]} codewords, each of length {C.shape[1]}.")
|
||||
|
||||
best_if_s1 = np.argmax(s1, axis=1)
|
||||
best_if_s2 = np.argmax(s2, axis=1)
|
||||
|
||||
tot1 = np.sum(np.max(s1, axis=1))
|
||||
tot2 = np.sum(np.max(s2, axis=1))
|
||||
|
||||
state = 1 if tot1 >= tot2 else 2
|
||||
indices = best_if_s1 if state == 1 else best_if_s2
|
||||
chars = [index_to_pair(i) for i in indices]
|
||||
decoded = ''.join(a+b for a, b in chars)
|
||||
|
||||
msg_chars = []
|
||||
for idx in indices:
|
||||
first = ALPHABET[idx >> 6]
|
||||
second = ALPHABET[idx & 0x3F]
|
||||
msg_chars.append(first)
|
||||
msg_chars.append(second)
|
||||
|
||||
decoded = ''.join(msg_chars)
|
||||
return decoded, state
|
||||
|
||||
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
||||
return decode_blocks_with_state(Y, C)[0]
|
||||
|
||||
if __name__ == "__main__":
|
||||
Y = np.loadtxt("output.txt")
|
||||
C, _ = make_codebook()
|
||||
decoded, state = decode_blocks_with_state(Y, C)
|
||||
print(f"Decoded message: {decoded}")
|
||||
print(f"Detected channel state: {state}")
|
|
@ -1,103 +0,0 @@
|
|||
# decoder_backup.py
|
||||
import numpy as np
|
||||
from encoder_backup import ALPHABET, G
|
||||
|
||||
# Match the channel’s noise variance
|
||||
SIGMA2 = 10.0
|
||||
|
||||
def _block_ll_scores(Yb: np.ndarray,
|
||||
C1: np.ndarray,
|
||||
C2: np.ndarray,
|
||||
sqrtG: float
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Compute per-symbol log-likelihood scores for one interleaved block Yb
|
||||
under the two channel states (even-boosted vs. odd-boosted).
|
||||
Returns (scores_state1, scores_state2).
|
||||
"""
|
||||
|
||||
# Split received block into even/odd samples
|
||||
Ye, Yo = Yb[0::2], Yb[1::2]
|
||||
|
||||
# Precompute the squared-norm penalties for each codeword half
|
||||
# (these come from the -||Y - H_s C||^2 term)
|
||||
# state1: even half is √G * C1, odd half is C2
|
||||
E1 = G * np.sum(C1**2, axis=1) + np.sum(C2**2, axis=1)
|
||||
# state2: even half is C1, odd half is √G * C2
|
||||
E2 = np.sum(C1**2, axis=1) + G * np.sum(C2**2, axis=1)
|
||||
|
||||
# Correlation terms
|
||||
corr1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
|
||||
corr2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
|
||||
|
||||
# ML log‐likelihood (up to constant −1/(2σ²)):
|
||||
scores1 = (corr1 - 0.5 * E1) / SIGMA2
|
||||
scores2 = (corr2 - 0.5 * E2) / SIGMA2
|
||||
|
||||
return scores1, scores2
|
||||
|
||||
|
||||
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
||||
"""
|
||||
Per-block ML decoding marginalizing over the unknown state:
|
||||
for each block, compute scores1/2 via _block_ll_scores, then
|
||||
marginal_score = logaddexp(scores1, scores2) and pick argmax.
|
||||
"""
|
||||
n = C.shape[1]
|
||||
assert Y.size % n == 0, "Y length must be a multiple of codeword length"
|
||||
num_blocks = Y.size // n
|
||||
|
||||
half = n // 2
|
||||
C1, C2 = C[:, :half], C[:, half:]
|
||||
sqrtG = np.sqrt(G)
|
||||
|
||||
recovered = []
|
||||
for k in range(num_blocks):
|
||||
Yb = Y[k*n:(k+1)*n]
|
||||
s1, s2 = _block_ll_scores(Yb, C1, C2, sqrtG)
|
||||
# marginal log-likelihood per symbol
|
||||
marg = np.logaddexp(s1, s2)
|
||||
best = int(np.argmax(marg))
|
||||
recovered.append(ALPHABET[best])
|
||||
|
||||
return "".join(recovered)
|
||||
|
||||
|
||||
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> (str, int):
|
||||
"""
|
||||
Joint‐ML state estimation and decoding:
|
||||
- For each block, get per-state scores via _block_ll_scores
|
||||
- Pick the best symbol index under each state; sum those log‐L’s across blocks
|
||||
- Choose the state with the higher total log‐L
|
||||
- Reconstruct the string using the best‐symbol indices for that state
|
||||
Returns (decoded_string, estimated_state).
|
||||
"""
|
||||
n = C.shape[1]
|
||||
assert Y.size % n == 0, "Y length must be a multiple of codeword length"
|
||||
num_blocks = Y.size // n
|
||||
|
||||
half = n // 2
|
||||
C1, C2 = C[:, :half], C[:, half:]
|
||||
sqrtG = np.sqrt(G)
|
||||
|
||||
total1, total2 = 0.0, 0.0
|
||||
best1, best2 = [], []
|
||||
|
||||
for k in range(num_blocks):
|
||||
Yb = Y[k*n:(k+1)*n]
|
||||
s1, s2 = _block_ll_scores(Yb, C1, C2, sqrtG)
|
||||
|
||||
idx1 = int(np.argmax(s1))
|
||||
idx2 = int(np.argmax(s2))
|
||||
|
||||
total1 += s1[idx1]
|
||||
total2 += s2[idx2]
|
||||
|
||||
best1.append(idx1)
|
||||
best2.append(idx2)
|
||||
|
||||
s_est = 1 if total1 >= total2 else 2
|
||||
chosen = best1 if s_est == 1 else best2
|
||||
decoded = "".join(ALPHABET[i] for i in chosen)
|
||||
|
||||
return decoded, s_est
|
117
encoder.py
117
encoder.py
|
@ -1,82 +1,65 @@
|
|||
# encoder.py — 2-char/1-codeword implementation
|
||||
import numpy as np
|
||||
from typing import Tuple
|
||||
import sys
|
||||
|
||||
##############################################################################
|
||||
# Public constants
|
||||
##############################################################################
|
||||
CHAR_SET = (
|
||||
[chr(i) for i in range(ord('a'), ord('z')+1)] +
|
||||
[chr(i) for i in range(ord('A'), ord('Z')+1)] +
|
||||
[str(i) for i in range(10)] +
|
||||
[' ', '.']
|
||||
)
|
||||
assert len(CHAR_SET) == 64
|
||||
CHAR_TO_IDX = {c: i for i, c in enumerate(CHAR_SET)}
|
||||
IDX_TO_CHAR = {i: c for c, i in CHAR_TO_IDX.items()}
|
||||
ALPHABET = "abcdefghijklmnopqrstuvwxyz" \
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" \
|
||||
"0123456789 ."
|
||||
|
||||
G = 10.0 # channel gain
|
||||
ENERGY_LIMIT = 2000.0 # global limit ‖x‖²
|
||||
TEXT_LEN = 40 # must stay 40
|
||||
CHAR_TO_IDX = {c: i for i, c in enumerate(ALPHABET)}
|
||||
|
||||
##############################################################################
|
||||
# Hadamard-codebook utilities
|
||||
##############################################################################
|
||||
def _hadamard(r: int) -> np.ndarray:
|
||||
G = 10.0
|
||||
ENERGY_LIMIT = 2000.0
|
||||
TEXT_LENGTH = 40
|
||||
ALPHABET_LENGTH = len(ALPHABET)
|
||||
|
||||
assert ALPHABET_LENGTH == 64
|
||||
|
||||
def pair_to_index(a: str, b: str) -> int:
|
||||
i1 = CHAR_TO_IDX[a]
|
||||
i2 = CHAR_TO_IDX[b]
|
||||
return (i1 << 6) + i2
|
||||
|
||||
def hadamard(r: int) -> np.ndarray:
|
||||
if r == 0:
|
||||
return np.array([[1.]], dtype=np.float32)
|
||||
H = _hadamard(r-1)
|
||||
return np.block([[H, H],
|
||||
[H, -H]])
|
||||
H = hadamard(r - 1)
|
||||
return np.block([[H, H], [H, -H]])
|
||||
|
||||
def _Br(r: int) -> np.ndarray:
|
||||
H = _hadamard(r)
|
||||
return np.vstack([H, -H]) # 2^(r+1) × 2^r
|
||||
def Br(r: int) -> np.ndarray:
|
||||
H = hadamard(r)
|
||||
return np.vstack([H, -H])
|
||||
|
||||
##############################################################################
|
||||
# Public API
|
||||
##############################################################################
|
||||
def make_codebook(r: int = 11,
|
||||
num_blocks: int = TEXT_LEN//2,
|
||||
energy_budget: float = ENERGY_LIMIT
|
||||
) -> Tuple[np.ndarray, float]:
|
||||
"""
|
||||
Builds the scaled codebook C (4096×4096) used by both encoder & decoder.
|
||||
|
||||
α is chosen so that **after the per-sample duplication** in encode_message,
|
||||
a full 20-block message consumes exactly `energy_budget`.
|
||||
"""
|
||||
B = _Br(r) # 4096 × 2048
|
||||
C = np.hstack((B, B)).astype(np.float32) # 4096 × 4096
|
||||
n = C.shape[1] # 4096
|
||||
dup_factor = 2 # sample-duplication
|
||||
alpha = energy_budget / (num_blocks * dup_factor * n)
|
||||
def make_codebook(r: int = 11, num_blocks: int = TEXT_LENGTH // 2, Eb: float = ENERGY_LIMIT):
|
||||
B = Br(r)
|
||||
C = np.hstack((B, B)).astype(np.float32)
|
||||
n = C.shape[1]
|
||||
alpha = Eb / (num_blocks * 2 * n)
|
||||
C *= np.sqrt(alpha, dtype=C.dtype)
|
||||
return C, alpha
|
||||
|
||||
def pair_to_index(a: str, b: str) -> int:
|
||||
return 64 * CHAR_TO_IDX[a] + CHAR_TO_IDX[b]
|
||||
|
||||
def index_to_pair(k: int) -> tuple[str, str]:
|
||||
if not 0 <= k < 4096:
|
||||
raise ValueError("index out of range [0,4095]")
|
||||
return IDX_TO_CHAR[k // 64], IDX_TO_CHAR[k % 64]
|
||||
|
||||
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Encode a 40-character message. Each 2-character pair → one codeword row.
|
||||
After concatenation the whole signal is duplicated sample-wise so that
|
||||
the channel’s even / odd indices each carry one clean copy.
|
||||
"""
|
||||
if len(msg) != TEXT_LEN:
|
||||
if len(msg) != TEXT_LENGTH:
|
||||
raise ValueError("Message must be exactly 40 characters.")
|
||||
|
||||
pairs = [(msg[i], msg[i+1]) for i in range(0, TEXT_LEN, 2)]
|
||||
rows = [C[pair_to_index(a, b)] for a, b in pairs] # 20×4096
|
||||
signal = np.concatenate(rows).astype(np.float32)
|
||||
signal = np.repeat(signal, 2) # duplicate
|
||||
# tight numeric safety-check (≡ 2000, barring float error)
|
||||
e = np.sum(signal**2)
|
||||
if not np.isclose(e, ENERGY_LIMIT, atol=1e-3):
|
||||
raise RuntimeError(f"energy sanity check failed ({e:.3f} ≠ 2000)")
|
||||
idxs = [CHAR_TO_IDX[c] for c in msg]
|
||||
pair_idxs = [(idxs[i] << 6) | idxs[i+1] for i in range(0, TEXT_LENGTH, 2)]
|
||||
|
||||
signal = np.repeat(C[pair_idxs].ravel(), 2).astype(np.float32)
|
||||
|
||||
if not np.isclose(signal.dot(signal), ENERGY_LIMIT, atol=1e-3):
|
||||
raise RuntimeError("energy check failed")
|
||||
|
||||
return signal
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1:
|
||||
msg = sys.argv[1]
|
||||
else:
|
||||
msg = input(f"Enter a message ({TEXT_LENGTH} characters): ").strip()
|
||||
if len(msg) != TEXT_LENGTH:
|
||||
print(f"Message must be exactly {TEXT_LENGTH} characters.")
|
||||
sys.exit(1)
|
||||
C, _ = make_codebook()
|
||||
signal = encode_message(msg, C)
|
||||
np.savetxt("input.txt", signal)
|
||||
print("Signal written to input.txt")
|
|
@ -1,57 +0,0 @@
|
|||
# encoder_backup.py
|
||||
import numpy as np
|
||||
|
||||
# System parameters
|
||||
G = 10.0 # power gain for even samples
|
||||
ENERGY_LIMIT = 2000.0 # total energy per block
|
||||
ALPHABET = (
|
||||
[chr(i) for i in range(ord('a'), ord('z')+1)] +
|
||||
[chr(i) for i in range(ord('A'), ord('Z')+1)] +
|
||||
[str(i) for i in range(10)] +
|
||||
[' ', '.']
|
||||
)
|
||||
assert len(ALPHABET) == 64, "Alphabet must be size 64"
|
||||
|
||||
|
||||
def hadamard(r: int) -> np.ndarray:
|
||||
if r == 0:
|
||||
return np.array([[1]], dtype=float)
|
||||
M = hadamard(r-1)
|
||||
return np.block([[M, M], [M, -M]])
|
||||
|
||||
|
||||
def Br(r: int) -> np.ndarray:
|
||||
M = hadamard(r)
|
||||
return np.vstack([M, -M])
|
||||
|
||||
|
||||
def make_codebook(r: int, num_blocks: int) -> np.ndarray:
|
||||
"""
|
||||
Build 64x64 codebook and scale blocks so energy per block ≤ ENERGY_LIMIT/num_blocks
|
||||
"""
|
||||
B_full = Br(r) #
|
||||
B = B_full[: len(ALPHABET), :] # now shape (64, 2^r)
|
||||
C = np.hstack([B, B]).astype(float) # shape (64, 2^{r+1})
|
||||
raw_norm = np.sum(C[0]**2)
|
||||
margin = 0.99
|
||||
alpha = margin * (ENERGY_LIMIT / num_blocks) / raw_norm
|
||||
return np.sqrt(alpha) * C
|
||||
|
||||
|
||||
def interleave(c: np.ndarray) -> np.ndarray:
|
||||
half = c.size // 2
|
||||
x = np.empty(c.size)
|
||||
x[0::2] = c[:half]
|
||||
x[1::2] = c[half:]
|
||||
return x
|
||||
|
||||
|
||||
def encode_message(msg: str, C: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Encode 40-character message into interleaved code symbols
|
||||
"""
|
||||
if len(msg) != 40:
|
||||
raise ValueError("Message must be exactly 40 characters.")
|
||||
idx = [ALPHABET.index(ch) for ch in msg]
|
||||
blocks = [interleave(C[i]) for i in idx]
|
||||
return np.concatenate(blocks)
|
97
main.py
97
main.py
|
@ -1,65 +1,70 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
unchanged CLI – only the *r* and the assert on message length moved to encoder
|
||||
"""
|
||||
import argparse, sys, numpy as np, subprocess, pathlib, os, tempfile
|
||||
import argparse
|
||||
import numpy as np
|
||||
import encoder, decoder, channel
|
||||
import subprocess, pathlib, os, tempfile, sys
|
||||
|
||||
INPUT_FILE = OUTPUT_FILE = None
|
||||
|
||||
def transmit(msg, C): return encoder.encode_message(msg, C)
|
||||
def receive_local(x): return channel.channel(x)
|
||||
def transmit(msg, C):
|
||||
return encoder.encode_message(msg, C)
|
||||
|
||||
def receive_server(x, host, port):
|
||||
global INPUT_FILE, OUTPUT_FILE
|
||||
if INPUT_FILE:
|
||||
in_f, rm_in = INPUT_FILE, False
|
||||
np.savetxt(in_f, x)
|
||||
else:
|
||||
tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
|
||||
np.savetxt(tf.name, x); tf.close()
|
||||
in_f, rm_in = tf.name, True
|
||||
if OUTPUT_FILE:
|
||||
out_f, rm_out = OUTPUT_FILE, False
|
||||
else:
|
||||
fd, out_f = tempfile.mkstemp(suffix='.txt'); os.close(fd); rm_out = True
|
||||
def receive_local(x):
|
||||
return channel.channel(x)
|
||||
|
||||
cmd = [sys.executable, str(pathlib.Path(__file__).parent/'client.py'),
|
||||
'--input_file', in_f, '--output_file', out_f,
|
||||
'--srv_hostname', host, '--srv_port', str(port)]
|
||||
def receive_server(x, host, port, input_file=None, output_file=None):
|
||||
in_f = input_file or tempfile.NamedTemporaryFile(suffix='.txt', delete=False).name
|
||||
np.savetxt(in_f, x)
|
||||
|
||||
out_f = output_file or tempfile.mkstemp(suffix='.txt')[1]
|
||||
|
||||
cmd = [
|
||||
sys.executable, str(pathlib.Path(__file__).parent / 'client.py'),
|
||||
'--input_file', in_f, '--output_file', out_f,
|
||||
'--srv_hostname', host, '--srv_port', str(port)
|
||||
]
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
Y = np.loadtxt(out_f)
|
||||
finally:
|
||||
if rm_in and os.path.exists(in_f): os.remove(in_f)
|
||||
if rm_out and os.path.exists(out_f): os.remove(out_f)
|
||||
if not input_file and os.path.exists(in_f):
|
||||
os.remove(in_f)
|
||||
if not output_file and os.path.exists(out_f):
|
||||
os.remove(out_f)
|
||||
return Y
|
||||
|
||||
def receive(x, mode, host, port):
|
||||
return receive_local(x) if mode=='local' else receive_server(x,host,port)
|
||||
|
||||
def test(msg, n_trials, mode, host, port):
|
||||
C, _ = encoder.make_codebook() # r=11 by default
|
||||
def receive(x, mode, host, port, input_file=None, output_file=None):
|
||||
if mode == 'local':
|
||||
return receive_local(x)
|
||||
return receive_server(x, host, port, input_file, output_file)
|
||||
|
||||
|
||||
def test(msg, n_trials, mode, host, port, input_file=None, output_file=None):
|
||||
C, _ = encoder.make_codebook()
|
||||
print(f"Using codebook with {C.shape[0]} codewords, {C.shape[1]} symbols each.")
|
||||
ok = 0
|
||||
for _ in range(n_trials):
|
||||
x = transmit(msg, C)
|
||||
print(f"Transmitted {len(x):,} samples (energy={np.dot(x,x):.2f})")
|
||||
y = receive(x, mode, host, port)
|
||||
print(f"Received {len(y):,} samples (energy={np.dot(y,y):.2f})")
|
||||
x = transmit(msg, C)
|
||||
print(f"Transmitted {len(x):,} samples (energy={np.dot(x, x):.2f})")
|
||||
y = receive(x, mode, host, port, input_file, output_file)
|
||||
print(f"Received {len(y):,} samples (energy={np.dot(y, y):.2f})")
|
||||
est, _ = decoder.decode_blocks_with_state(y, C)
|
||||
if est == msg: ok += 1
|
||||
print(f"{ok}/{n_trials} perfect decodes ({100*ok/n_trials:.2f}%)")
|
||||
if est == msg:
|
||||
ok += 1
|
||||
print(f"{ok}/{n_trials} perfect decodes ({100 * ok / n_trials:.2f}%)")
|
||||
|
||||
def _args():
|
||||
p=argparse.ArgumentParser()
|
||||
p.add_argument('-m','--message',required=True,help='exactly 40 chars')
|
||||
p.add_argument('-n','--trials', type=int, default=200)
|
||||
p.add_argument('--mode',choices=['local','server'],default='local')
|
||||
p.add_argument('--hostname',default='iscsrv72.epfl.ch'); p.add_argument('--port',type=int,default=80)
|
||||
p.add_argument('-i','--input_file'); p.add_argument('-o','--output_file')
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument('-m', '--message', required=True, help='exactly 40 chars')
|
||||
p.add_argument('-n', '--trials', type=int, default=200)
|
||||
p.add_argument('--mode', choices=['local', 'server'], default='local')
|
||||
p.add_argument('--hostname', default='iscsrv72.epfl.ch')
|
||||
p.add_argument('--port', type=int, default=80)
|
||||
p.add_argument('-i', '--input_file')
|
||||
p.add_argument('-o', '--output_file')
|
||||
return p.parse_args()
|
||||
|
||||
if __name__=='__main__':
|
||||
a=_args(); INPUT_FILE=a.input_file; OUTPUT_FILE=a.output_file
|
||||
test(a.message, a.trials, a.mode, a.hostname, a.port)
|
||||
|
||||
if __name__ == '__main__':
|
||||
a = parse_args()
|
||||
test(a.message, a.trials, a.mode, a.hostname, a.port, a.input_file, a.output_file)
|
||||
|
|
119
main_backup.py
119
main_backup.py
|
@ -1,119 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import numpy as np
|
||||
import encoder
|
||||
import decoder
|
||||
import channel
|
||||
import subprocess
|
||||
import pathlib
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
# Global paths for debugging
|
||||
INPUT_FILE = None
|
||||
OUTPUT_FILE = None
|
||||
|
||||
def transmit(msg, C):
|
||||
return encoder.encode_message(msg, C)
|
||||
|
||||
def receive_local(c):
|
||||
return channel.channel(c)
|
||||
|
||||
def receive_server(c, hostname, port):
|
||||
global INPUT_FILE, OUTPUT_FILE
|
||||
# write or temp-file for input
|
||||
if INPUT_FILE:
|
||||
in_name, delete_in = INPUT_FILE, False
|
||||
np.savetxt(in_name, c)
|
||||
else:
|
||||
tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
|
||||
np.savetxt(tf.name, c)
|
||||
in_name, delete_in = tf.name, True
|
||||
tf.close()
|
||||
# write or temp-file for output
|
||||
if OUTPUT_FILE:
|
||||
out_name, delete_out = OUTPUT_FILE, False
|
||||
else:
|
||||
fd, out_name = tempfile.mkstemp(suffix='.txt')
|
||||
os.close(fd)
|
||||
delete_out = True
|
||||
|
||||
# call client.py
|
||||
cmd = [
|
||||
sys.executable,
|
||||
str(pathlib.Path(__file__).parent / 'client.py'),
|
||||
'--input_file', in_name,
|
||||
'--output_file', out_name,
|
||||
'--srv_hostname', hostname,
|
||||
'--srv_port', str(port)
|
||||
]
|
||||
try:
|
||||
subprocess.run(cmd, check=True)
|
||||
Y = np.loadtxt(out_name)
|
||||
finally:
|
||||
if delete_in and os.path.exists(in_name): os.remove(in_name)
|
||||
if delete_out and os.path.exists(out_name): os.remove(out_name)
|
||||
|
||||
return Y
|
||||
|
||||
def receive(c, mode, hostname, port):
|
||||
if mode == 'local':
|
||||
return receive_local(c)
|
||||
elif mode == 'server':
|
||||
return receive_server(c, hostname, port)
|
||||
else:
|
||||
raise ValueError("mode must be 'local' or 'server'")
|
||||
|
||||
def test_performance(msg, num_trials, mode, hostname, port):
|
||||
if len(msg) != 40:
|
||||
raise ValueError('Message must be exactly 40 characters.')
|
||||
|
||||
# build the same codebook
|
||||
C = encoder.make_codebook(r=11, num_blocks=40)
|
||||
|
||||
successes = 0
|
||||
print(f"Original message: {msg!r}")
|
||||
print(f"Running {num_trials} trials over '{mode}' channel...\n")
|
||||
|
||||
for i in range(num_trials):
|
||||
# TX → channel → RX
|
||||
c = transmit(msg, C)
|
||||
Y = receive(c, mode, hostname, port)
|
||||
|
||||
# decode with state‐estimation
|
||||
est, s_est = decoder.decode_blocks_with_state(Y, C)
|
||||
|
||||
# count char errors
|
||||
errors = sum(1 for a,b in zip(est, msg) if a!=b)
|
||||
if errors == 0:
|
||||
successes += 1
|
||||
|
||||
print(f"Trial {i+1:3d}: state={s_est} decoded={est!r} errors={errors}")
|
||||
|
||||
pct = 100 * successes / num_trials
|
||||
print("\n=== Summary ===")
|
||||
print(f" Total trials: {num_trials}")
|
||||
print(f" Perfect decodings: {successes}")
|
||||
print(f" Overall accuracy: {pct:.2f}%")
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser(description='Test comms system')
|
||||
p.add_argument('--message','-m', required=True, help='40-char message')
|
||||
p.add_argument('--trials','-n', type=int, default=200, help='Number of trials')
|
||||
p.add_argument('--mode', choices=['local','server'], default='local')
|
||||
p.add_argument('--hostname', default='iscsrv72.epfl.ch')
|
||||
p.add_argument('--port', type=int, default=80)
|
||||
p.add_argument('--input_file','-i', default=None,
|
||||
help='(server mode) where to write input.txt')
|
||||
p.add_argument('--output_file','-o',default=None,
|
||||
help='(server mode) where to write output.txt')
|
||||
return p.parse_args()
|
||||
|
||||
if __name__=='__main__':
|
||||
args = parse_args()
|
||||
INPUT_FILE = args.input_file
|
||||
OUTPUT_FILE = args.output_file
|
||||
test_performance(args.message, args.trials,
|
||||
args.mode, args.hostname, args.port)
|
Loading…
Add table
Reference in a new issue