diff --git a/france/france_rsvp/codes.py b/france/france_rsvp/codes.py index fdaa0d6..16f93ca 100644 --- a/france/france_rsvp/codes.py +++ b/france/france_rsvp/codes.py @@ -11,6 +11,45 @@ assert len(ALPHABET) == 32 LENGTH = 4 +def list_errors(code): + for place in range(LENGTH): + for replacement in ALPHABET: + if replacement == code[place]: + continue + new_code = code[:place] + replacement + code[place + 1:] + yield new_code + + +correct_codes = set() +errored_codes = {} + + +def _load_codes(): + global correct_codes + global errored_codes + + # Get the current codes + with get_db() as db: + correct_codes = set(row[0] for row in db.execute( + '''\ + SELECT code FROM families; + ''', + )) + + # Augment them with possible errored codes + errored_codes = {} + for code in correct_codes: + for error in list_errors(code): + errored_codes[error] = code + assert ( + len(errored_codes) + == LENGTH * (len(ALPHABET) - 1) * len(correct_codes) + ) + + +_load_codes() + + CORRECT = { 'i': '1', 'l': '1', @@ -21,16 +60,8 @@ CORRECT = { def correct_code(code): code = code.lower() fixed_code = ''.join(CORRECT.get(c, c) for c in code) - return fixed_code - - -def list_errors(code): - for place in range(LENGTH): - for replacement in ALPHABET: - if replacement == code[place]: - continue - new_code = code[:place] + replacement + code[place + 1:] - yield new_code + fixed_code = errored_codes.get(fixed_code, None) + return fixed_code or code def main(): @@ -43,20 +74,6 @@ def main(): else: raise ValueError - # Get the current codes - with get_db() as db: - correct_codes = set(row[0] for row in db.execute( - '''\ - SELECT code FROM families; - ''', - )) - - # Augment them with possible errored codes - errored_codes = set() - for code in correct_codes: - errored_codes.update(list_errors(code)) - assert len(errored_codes) == LENGTH * (len(ALPHABET) - 1) * len(correct_codes) - # Generate new codes generated = 0 while generated < number: @@ -64,6 +81,7 @@ def main(): if code in correct_codes or code in errored_codes: continue correct_codes.add(code) - errored_codes.update(list_errors(code)) + for error in list_errors(code): + errored_codes[error] = code print(code) generated += 1