package main

import (
    "fmt"
    "math/big"
    "os"
    "runtime"
    "strconv"
)

// BrotRSA searches for a private exponent that differs from candidate in at
// most n_errors bit positions and sends the outcome on ch.
//
// A fixed test value (modulus/2) is encrypted once with pubEx. Each trial
// exponent is candidate XOR a mask of flipped bits; it is accepted when it
// decrypts the test ciphertext back to the test value.
//
// The lowest flipped bit position is taken from [start_pos, end_pos); the
// remaining positions are strictly increasing above it. Positions above
// candidate.BitLen()+n_errors are not applied to the mask, which is how
// trials with fewer than n_errors flipped bits are produced.
//
// Exactly one value is sent on ch: a pointer to the matching exponent, or nil
// if no trial in this worker's range matched. With n_errors <= 0 only the
// unmodified candidate is tested.
func BrotRSA(n_errors, start_pos, end_pos int, candidate, pubEx, modulus *big.Int, ch chan *big.Int) {
	var enc, wanted, privExp, rsaRes, mask big.Int

	// Known-answer pair: wanted is the plaintext, enc its encryption.
	wanted.Div(modulus, big.NewInt(2))
	enc.Exp(&wanted, pubEx, modulus)

	if n_errors <= 0 {
		// Nothing to flip; the position slice would be empty, so handle the
		// candidate-only check here instead of indexing into it.
		rsaRes.Exp(&enc, candidate, modulus)
		if rsaRes.Cmp(&wanted) == 0 {
			ch <- new(big.Int).Set(candidate)
			return
		}
		ch <- nil
		return
	}

	// runPos holds the current, strictly increasing set of bit positions.
	runPos := make([]int, n_errors)
	for i := range runPos {
		runPos[i] = start_pos + i
	}
	last := n_errors - 1
	maxPos := candidate.BitLen() + n_errors

	for runPos[0] < end_pos {
		// Build the flip mask, reusing the same big.Int across iterations.
		mask.SetInt64(0)
		for _, pos := range runPos {
			if pos <= maxPos {
				mask.SetBit(&mask, pos, 1)
			}
		}

		privExp.Xor(candidate, &mask)
		rsaRes.Exp(&enc, &privExp, modulus)
		if rsaRes.Cmp(&wanted) == 0 {
			ch <- &privExp
			return
		}

		// Advance to the next combination. Positions may run one past maxPos
		// so that the "this slot flips nothing" case is visited as well.
		i := last
		runPos[i]++
		for i > 0 && runPos[i] > maxPos+1 {
			i--
			runPos[i]++
		}
		// Re-seed every slot above the one that absorbed the carry. Resetting
		// only the overflowing slot leaves stale values behind on a cascading
		// carry and skips combinations such as (a, a+1, c).
		for j := i + 1; j <= last; j++ {
			runPos[j] = runPos[j-1] + 1
		}
	}

	ch <- nil
}

// LaunchWorkers partitions the range of lowest flipped-bit positions into
// num_workers contiguous slices, runs one BrotRSA goroutine per slice and
// returns the first non-nil exponent reported, or nil if every worker reports
// nil.
//
// If num_workers is less than 1, one worker per CPU is used. Workers that are
// still running when a result is returned are not cancelled; they run to
// completion in the background.
func LaunchWorkers(num_workers, n_errors int, candidate, pubEx, modulus *big.Int) *big.Int {
	if num_workers < 1 {
		// Guard the divisions below against zero or negative worker counts.
		num_workers = runtime.NumCPU()
	}

	bits := candidate.BitLen()
	split := bits / num_workers
	// Leftover positions (the division remainder plus n_errors extra) are
	// handed out one per worker, starting with the first.
	addWork := n_errors + bits%num_workers

	// Buffered with one slot per worker so that a worker finishing after the
	// first hit has been returned can still deliver its result and exit
	// instead of blocking forever on the send.
	ch := make(chan *big.Int, num_workers)

	shifted := 0 // extra positions already given to earlier workers
	for i := 0; i < num_workers; i++ {
		start := i*split + shifted
		end := start + split
		if i == num_workers-1 {
			// The last worker also gets n_errors positions past the end.
			end += n_errors
		}
		if addWork > 0 {
			addWork--
			end++
			shifted++
		}
		go BrotRSA(n_errors, start, end, candidate, pubEx, modulus, ch)
	}

	for i := 0; i < num_workers; i++ {
		if res := <-ch; res != nil {
			return res
		}
	}

	return nil
}

// main parses the command line, runs the search and prints the recovered
// private exponent in hexadecimal if one was found.
//
// Arguments: num_errors num_workers candidate pub_exponent modulus.
// Malformed arguments are reported on stderr and terminate the program with
// exit status 1 instead of being silently treated as zero.
func main() {
	if len(os.Args) < 6 {
		fmt.Println("call: brotforce num_errors num_worker candidate pub_exponent modulus")
		fmt.Println("Num_errors: Maxmimal expected number of errors")
		fmt.Println("num_workers: Number of workers used, if 0 one worker per CPU is started")
		fmt.Println("candidate, pub_exponent, modulus:")
		fmt.Println("Numbers with prefixes are interpreted: 0x (hex) or 0b (bin)")
		fmt.Println("Numbers without prefix are intepreted as decimal")
		os.Exit(1)
	}

	numErrors, err := strconv.Atoi(os.Args[1])
	if err != nil {
		fmt.Fprintf(os.Stderr, "invalid num_errors %q: %v\n", os.Args[1], err)
		os.Exit(1)
	}
	if numErrors < 0 {
		fmt.Fprintf(os.Stderr, "invalid num_errors %d: must not be negative\n", numErrors)
		os.Exit(1)
	}

	numWorkers, err := strconv.Atoi(os.Args[2])
	if err != nil {
		fmt.Fprintf(os.Stderr, "invalid num_workers %q: %v\n", os.Args[2], err)
		os.Exit(1)
	}
	if numWorkers < 1 {
		numWorkers = runtime.NumCPU()
	}

	var cand, pubEx, mod big.Int
	numbers := []struct {
		name string
		dst  *big.Int
		text string
	}{
		{"candidate", &cand, os.Args[3]},
		{"pub_exponent", &pubEx, os.Args[4]},
		{"modulus", &mod, os.Args[5]},
	}
	for _, num := range numbers {
		// Base 0 lets SetString pick the base from the 0x / 0b prefix.
		if _, ok := num.dst.SetString(num.text, 0); !ok {
			fmt.Fprintf(os.Stderr, "invalid %s %q: not a number\n", num.name, num.text)
			os.Exit(1)
		}
	}
	if mod.Sign() <= 0 {
		// big.Int.Exp performs no modular reduction for a zero modulus.
		fmt.Fprintln(os.Stderr, "invalid modulus: must be positive")
		os.Exit(1)
	}

	if res := LaunchWorkers(numWorkers, numErrors, &cand, &pubEx, &mod); res != nil {
		fmt.Printf("0x%x\n", res)
	}
}