feat: improve decoder.py
This commit is contained in:
parent
177bc566cf
commit
8b876c0d06
2 changed files with 83 additions and 77 deletions
39
decoder.py
39
decoder.py
|
@ -5,7 +5,7 @@ from encoder import ALPHABET, G
|
||||||
|
|
||||||
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
||||||
"""
|
"""
|
||||||
Decode received samples by maximum correlation score
|
Decode received samples by maximum correlation score (state unknown, per-block).
|
||||||
"""
|
"""
|
||||||
n = C.shape[1]
|
n = C.shape[1]
|
||||||
half = n // 2
|
half = n // 2
|
||||||
|
@ -15,7 +15,7 @@ def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
||||||
sqrtG = np.sqrt(G)
|
sqrtG = np.sqrt(G)
|
||||||
recovered = []
|
recovered = []
|
||||||
for k in range(num):
|
for k in range(num):
|
||||||
Yb = Y[k*n:(k+1)*n]
|
Yb = Y[k * n:(k + 1) * n]
|
||||||
Ye, Yo = Yb[0::2], Yb[1::2]
|
Ye, Yo = Yb[0::2], Yb[1::2]
|
||||||
s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
|
s1 = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
|
||||||
s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
|
s2 = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
|
||||||
|
@ -25,8 +25,37 @@ def decode_blocks(Y: np.ndarray, C: np.ndarray) -> str:
|
||||||
return "".join(recovered)
|
return "".join(recovered)
|
||||||
|
|
||||||
|
|
||||||
def count_errors(orig: str, est: str):
|
def decode_blocks_with_state(Y: np.ndarray, C: np.ndarray) -> (str, int):
|
||||||
"""
|
"""
|
||||||
List mismatches between orig and est
|
1) Estimate the single channel state s∈{1,2} by comparing total energy
|
||||||
|
on even vs odd positions across the entire Y.
|
||||||
|
2) Decode **all** blocks using that one state’s scoring rule.
|
||||||
|
|
||||||
|
Returns (decoded_string, estimated_state).
|
||||||
"""
|
"""
|
||||||
return [(i, o, e) for i, (o, e) in enumerate(zip(orig, est)) if o != e]
|
n = C.shape[1]
|
||||||
|
half = n // 2
|
||||||
|
num = Y.size // n
|
||||||
|
C1 = C[:, :half]
|
||||||
|
C2 = C[:, half:]
|
||||||
|
sqrtG = np.sqrt(G)
|
||||||
|
|
||||||
|
# 1) state estimate from full-length energies
|
||||||
|
Ye_all = Y[0::2]
|
||||||
|
Yo_all = Y[1::2]
|
||||||
|
E_even = np.sum(Ye_all ** 2)
|
||||||
|
E_odd = np.sum(Yo_all ** 2)
|
||||||
|
s_est = 1 if E_even > E_odd else 2
|
||||||
|
|
||||||
|
recovered = []
|
||||||
|
for k in range(num):
|
||||||
|
Yb = Y[k * n:(k + 1) * n]
|
||||||
|
Ye, Yo = Yb[0::2], Yb[1::2]
|
||||||
|
if s_est == 1:
|
||||||
|
score = sqrtG * (Ye @ C1.T) + (Yo @ C2.T)
|
||||||
|
else:
|
||||||
|
score = (Ye @ C1.T) + sqrtG * (Yo @ C2.T)
|
||||||
|
best = int(np.argmax(score))
|
||||||
|
recovered.append(ALPHABET[best])
|
||||||
|
|
||||||
|
return "".join(recovered), s_est
|
||||||
|
|
121
main.py
121
main.py
|
@ -16,127 +16,104 @@ INPUT_FILE = None
|
||||||
OUTPUT_FILE = None
|
OUTPUT_FILE = None
|
||||||
|
|
||||||
def transmit(msg, C):
|
def transmit(msg, C):
|
||||||
'''
|
|
||||||
Transmitter: encodes the message into real-valued samples using the codebook C.
|
|
||||||
'''
|
|
||||||
return encoder.encode_message(msg, C)
|
return encoder.encode_message(msg, C)
|
||||||
|
|
||||||
def receive_local(c):
|
def receive_local(c):
|
||||||
'''
|
|
||||||
Sends the samples through the local channel simulation.
|
|
||||||
'''
|
|
||||||
return channel.channel(c)
|
return channel.channel(c)
|
||||||
|
|
||||||
def receive_server(c, hostname, port):
|
def receive_server(c, hostname, port):
|
||||||
'''
|
|
||||||
Sends the samples to the remote server via client.py. If INPUT_FILE and/or
|
|
||||||
OUTPUT_FILE are set, uses those filenames (and preserves them); otherwise uses temporary files.
|
|
||||||
'''
|
|
||||||
global INPUT_FILE, OUTPUT_FILE
|
global INPUT_FILE, OUTPUT_FILE
|
||||||
# Determine input file path
|
# write or temp-file for input
|
||||||
if INPUT_FILE:
|
if INPUT_FILE:
|
||||||
in_name = INPUT_FILE
|
in_name, delete_in = INPUT_FILE, False
|
||||||
np.savetxt(in_name, c)
|
np.savetxt(in_name, c)
|
||||||
delete_in = False
|
|
||||||
else:
|
else:
|
||||||
with tempfile.NamedTemporaryFile(suffix='.txt', delete=False) as in_f:
|
tf = tempfile.NamedTemporaryFile(suffix='.txt', delete=False)
|
||||||
np.savetxt(in_f.name, c)
|
np.savetxt(tf.name, c)
|
||||||
in_name = in_f.name
|
in_name, delete_in = tf.name, True
|
||||||
delete_in = True
|
tf.close()
|
||||||
# Determine output file path
|
# write or temp-file for output
|
||||||
if OUTPUT_FILE:
|
if OUTPUT_FILE:
|
||||||
out_name = OUTPUT_FILE
|
out_name, delete_out = OUTPUT_FILE, False
|
||||||
delete_out = False
|
|
||||||
else:
|
else:
|
||||||
out_fd, out_name = tempfile.mkstemp(suffix='.txt')
|
fd, out_name = tempfile.mkstemp(suffix='.txt')
|
||||||
os.close(out_fd)
|
os.close(fd)
|
||||||
delete_out = True
|
delete_out = True
|
||||||
# Invoke client.py
|
|
||||||
|
# call client.py
|
||||||
cmd = [
|
cmd = [
|
||||||
sys.executable,
|
sys.executable,
|
||||||
str(pathlib.Path(__file__).parent / 'client.py'),
|
str(pathlib.Path(__file__).parent / 'client.py'),
|
||||||
'--input_file', in_name,
|
'--input_file', in_name,
|
||||||
'--output_file', out_name,
|
'--output_file', out_name,
|
||||||
'--srv_hostname', hostname,
|
'--srv_hostname', hostname,
|
||||||
'--srv_port', str(port)
|
'--srv_port', str(port)
|
||||||
]
|
]
|
||||||
try:
|
try:
|
||||||
subprocess.run(cmd, check=True)
|
subprocess.run(cmd, check=True)
|
||||||
Y = np.loadtxt(out_name)
|
Y = np.loadtxt(out_name)
|
||||||
finally:
|
finally:
|
||||||
if delete_in and os.path.exists(in_name):
|
if delete_in and os.path.exists(in_name): os.remove(in_name)
|
||||||
os.remove(in_name)
|
if delete_out and os.path.exists(out_name): os.remove(out_name)
|
||||||
if delete_out and os.path.exists(out_name):
|
|
||||||
os.remove(out_name)
|
|
||||||
return Y
|
return Y
|
||||||
|
|
||||||
def receive(c, mode, hostname, port):
|
def receive(c, mode, hostname, port):
|
||||||
'''
|
|
||||||
Wrapper to choose local or server channel.
|
|
||||||
'''
|
|
||||||
if mode == 'local':
|
if mode == 'local':
|
||||||
return receive_local(c)
|
return receive_local(c)
|
||||||
elif mode == 'server':
|
elif mode == 'server':
|
||||||
return receive_server(c, hostname, port)
|
return receive_server(c, hostname, port)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Mode must be 'local' or 'server'")
|
raise ValueError("mode must be 'local' or 'server'")
|
||||||
|
|
||||||
def test_performance(msg, num_trials, mode, hostname, port):
|
def test_performance(msg, num_trials, mode, hostname, port):
|
||||||
'''
|
|
||||||
Runs num_trials transmissions of msg through the specified channel and reports
|
|
||||||
per-trial decoded messages and error counts, plus overall accuracy.
|
|
||||||
'''
|
|
||||||
if len(msg) != 40:
|
if len(msg) != 40:
|
||||||
raise ValueError('Message must be exactly 40 characters.')
|
raise ValueError('Message must be exactly 40 characters.')
|
||||||
# Build codebook for 64 symbols, 40 blocks
|
|
||||||
|
# build the same codebook
|
||||||
C = encoder.make_codebook(r=5, num_blocks=40)
|
C = encoder.make_codebook(r=5, num_blocks=40)
|
||||||
|
|
||||||
successes = 0
|
successes = 0
|
||||||
|
print(f"Original message: {msg!r}")
|
||||||
print(f"Original message: {msg}")
|
|
||||||
print(f"Running {num_trials} trials over '{mode}' channel...\n")
|
print(f"Running {num_trials} trials over '{mode}' channel...\n")
|
||||||
|
|
||||||
for i in range(num_trials):
|
for i in range(num_trials):
|
||||||
# Transmit
|
# TX → channel → RX
|
||||||
c = transmit(msg, C)
|
c = transmit(msg, C)
|
||||||
# Channel
|
|
||||||
Y = receive(c, mode, hostname, port)
|
Y = receive(c, mode, hostname, port)
|
||||||
# Decode
|
|
||||||
est = decoder.decode_blocks(Y, C)
|
# decode with state‐estimation
|
||||||
# Count character errors
|
est, s_est = decoder.decode_blocks_with_state(Y, C)
|
||||||
errors = sum(1 for a, b in zip(est, msg) if a != b)
|
|
||||||
# Tally success if no errors
|
# count char errors
|
||||||
|
errors = sum(1 for a,b in zip(est, msg) if a!=b)
|
||||||
if errors == 0:
|
if errors == 0:
|
||||||
successes += 1
|
successes += 1
|
||||||
# Print trial result
|
|
||||||
print(f"Trial {i+1:3d}: Decoded: '{est}' | Errors: {errors}")
|
|
||||||
|
|
||||||
pct = successes / num_trials * 100
|
print(f"Trial {i+1:3d}: state={s_est} decoded={est!r} errors={errors}")
|
||||||
|
|
||||||
|
pct = 100 * successes / num_trials
|
||||||
print("\n=== Summary ===")
|
print("\n=== Summary ===")
|
||||||
print(f"Total trials: {num_trials}")
|
print(f" Total trials: {num_trials}")
|
||||||
print(f"Perfect decodings: {successes}")
|
print(f" Perfect decodings: {successes}")
|
||||||
print(f"Overall accuracy: {pct:.2f}%")
|
print(f" Overall accuracy: {pct:.2f}%")
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description='Test communication system performance.')
|
p = argparse.ArgumentParser(description='Test comms system')
|
||||||
parser.add_argument('--message', '-m', type=str, required=True,
|
p.add_argument('--message','-m', required=True, help='40-char message')
|
||||||
help='40-character message to send.')
|
p.add_argument('--trials','-n', type=int, default=200, help='Number of trials')
|
||||||
parser.add_argument('--trials', '-n', type=int, default=200,
|
p.add_argument('--mode', choices=['local','server'], default='local')
|
||||||
help='Number of trials.')
|
p.add_argument('--hostname', default='iscsrv72.epfl.ch')
|
||||||
parser.add_argument('--mode', choices=['local','server'], default='local',
|
p.add_argument('--port', type=int, default=80)
|
||||||
help="Channel mode: 'local' or 'server'.")
|
p.add_argument('--input_file','-i', default=None,
|
||||||
parser.add_argument('--hostname', type=str, default='iscsrv72.epfl.ch',
|
help='(server mode) where to write input.txt')
|
||||||
help='Server hostname for server mode.')
|
p.add_argument('--output_file','-o',default=None,
|
||||||
parser.add_argument('--port', type=int, default=80,
|
help='(server mode) where to write output.txt')
|
||||||
help='Server port for server mode.')
|
return p.parse_args()
|
||||||
parser.add_argument('--input_file', '-i', type=str, default=None,
|
|
||||||
help='Path to write server input samples (input.txt).')
|
|
||||||
parser.add_argument('--output_file', '-o', type=str, default=None,
|
|
||||||
help='Path to write server output samples (output.txt).')
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__=='__main__':
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
INPUT_FILE = args.input_file
|
INPUT_FILE = args.input_file
|
||||||
OUTPUT_FILE = args.output_file
|
OUTPUT_FILE = args.output_file
|
||||||
test_performance(args.message, args.trials, args.mode, args.hostname, args.port)
|
test_performance(args.message, args.trials,
|
||||||
|
args.mode, args.hostname, args.port)
|
||||||
|
|
Loading…
Add table
Reference in a new issue