Undid parallelization changes, fairly sure it degraded perf

This reverts commit 5a296819aedfcfd9a318df524d5ba3a5091f3a69.
This commit is contained in:
Miguel M 2023-02-23 18:29:05 +00:00
parent deb080d29d
commit 2dd7160fba
3 changed files with 143 additions and 340 deletions

View File

@ -1,24 +1,15 @@
extern crate parking_lot;
extern crate pest;
extern crate pyo3;
#[macro_use]
extern crate pest_derive;
#[macro_use]
extern crate closure;
mod vm;
use parking_lot::Mutex;
use pyo3::prelude::*;
use vm::VmCode;
// We cache the A and B set up to a total of ~4Gb of memory,
// meaning ~ 6×10^7 elements in each set
const CACHE_SIZE_LIMIT: usize = 60_000_000;
// OS threads, so it's fine if it goes over the number of cores.
const THREADS: u8 = 8;
// We cache the A and B set up to a total of ~2Gb of memory,
// meaning ~ 10^7 elements in each set
const CACHE_SIZE_LIMIT: usize = 10_000_000;
#[pyclass]
struct BoundsResult {
@ -47,203 +38,108 @@ struct Prover {
relationship: VmCode,
}
struct CachedSets {
filter: VmCode,
n: u8,
p: u8,
k: u8,
is_x: bool,
}
#[derive(Debug)]
struct SplitCachedSetIterator {
filter: VmCode,
n: u8,
p: u8,
k: u8,
is_x: bool,
struct CachedSetIterator<'code> {
cache: Vec<u64>,
cache_size: usize,
filter: &'code VmCode,
n: u8,
p: u8,
k: u8,
cached: bool,
window_start: usize,
window_size: usize,
counter: usize,
counter: u64,
cache_counter: usize,
top: u64,
is_x: bool,
}
/// A spinlock on a vector of Mutexes.
///
/// Iterating over this object yields mutex guards in the order that they could
/// be acquired.
struct AsAreFree<'vec, T> {
inner: &'vec Vec<Mutex<T>>,
remaining: Vec<bool>,
remaining_count: usize,
}
impl<'vec, T> AsAreFree<'vec, T> {
fn new(vec: &'vec Vec<Mutex<T>>) -> Self {
let inner = vec;
let remaining = (0..vec.len()).map(|_| true).collect();
let remaining_count = vec.len();
AsAreFree {
inner,
remaining,
remaining_count,
}
}
}
impl<'vec, T> Iterator for AsAreFree<'vec, T> {
type Item = parking_lot::lock_api::MutexGuard<'vec, parking_lot::RawMutex, T>;
fn next(&mut self) -> Option<Self::Item> {
if self.remaining_count == 0 {
None
} else {
let len = self.remaining.len();
loop {
for i in 0..len {
if self.remaining[i] {
if let Some(guard) = self.inner[i].try_lock() {
self.remaining_count -= 1;
self.remaining[i] = false;
return Some(guard);
}
}
}
}
}
}
}
impl CachedSets {
fn create_x(filter: VmCode, n: u8, p: u8, k: u8) -> Self {
CachedSets {
impl<'code> CachedSetIterator<'code> {
fn create_x(filter: &'code VmCode, n: u8, p: u8, k: u8) -> Self {
CachedSetIterator {
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
filter,
n,
p,
k,
cached: false,
counter: 0,
cache_counter: 0,
top: 2_u64.pow(n as u32),
is_x: true,
}
}
fn create_y(filter: VmCode, n: u8, p: u8, k: u8) -> Self {
CachedSets {
fn create_y(filter: &'code VmCode, n: u8, p: u8, k: u8) -> Self {
CachedSetIterator {
cache: Vec::with_capacity(CACHE_SIZE_LIMIT),
filter,
n,
p,
k,
cached: false,
counter: 0,
cache_counter: 0,
top: 2_u64.pow(n as u32),
is_x: false,
}
}
fn split_into(self, threads: u8) -> Vec<SplitCachedSetIterator> {
let top = (1 << self.n) - 1;
let window_size = top / threads as usize;
let remainder = top % threads as usize;
let mut split_iterators = vec![];
for iter_index in 0..threads {
let window_start = iter_index as usize * window_size;
let window_size = if iter_index < threads - 1 {
window_size
} else {
window_size + remainder
};
let cache_size = if iter_index < threads - 1 {
CACHE_SIZE_LIMIT / threads as usize
} else {
CACHE_SIZE_LIMIT / threads as usize + CACHE_SIZE_LIMIT % threads as usize
};
let cache = Vec::<u64>::with_capacity(cache_size);
let split_iter = SplitCachedSetIterator {
filter: self.filter.clone(),
n: self.n,
p: self.p,
k: self.k,
is_x: self.is_x,
cache,
cache_size,
window_start,
window_size,
cached: false,
counter: window_start,
};
split_iterators.push(split_iter);
}
split_iterators
}
}
impl SplitCachedSetIterator {
fn reset(&mut self) {
if self.cached {
self.counter = 0;
} else {
self.counter = self.window_start;
}
self.counter = 0;
self.cache_counter = 0;
}
}
impl Iterator for &mut SplitCachedSetIterator {
impl<'code> Iterator for &mut CachedSetIterator<'code> {
type Item = u64;
fn next(&mut self) -> Option<Self::Item> {
loop {
if self.cached && self.counter < self.cache.len() {
let cache = &self.cache;
if self.counter < cache.len() {
let result = cache[self.counter];
self.counter += 1;
if self.cached {
if self.cache_counter < self.cache.len() {
let result = self.cache[self.cache_counter];
self.cache_counter += 1;
if self.counter == cache.len() {
self.counter = *cache.last().unwrap() as usize + 1;
}
return Some(result);
if self.cache_counter == CACHE_SIZE_LIMIT {
self.counter = self.cache.last().unwrap() + 1;
}
}
if self.counter == self.window_start + self.window_size {
self.cached = true;
return Some(result);
} else if self.cache.len() < CACHE_SIZE_LIMIT {
self.reset();
return None;
}
let included = if self.is_x {
vm::Vm::load(
&self.filter,
vm::Registers::load(self.counter as u64, 0, self.n, self.p, self.k),
)
.run()
.unwrap_bool()
} else {
vm::Vm::load(
&self.filter,
vm::Registers::load(0, self.counter as u64, self.n, self.p, self.k),
)
.run()
.unwrap_bool()
};
if self.counter == self.top {
self.cached = true;
self.reset();
return None;
}
if !included {
self.counter += 1;
continue;
}
let result = self.counter as u64;
if self.cache.len() < self.cache_size {
self.cache.push(result);
}
let included = if self.is_x {
vm::Vm::load(
self.filter,
vm::Registers::load(self.counter, 0, self.n, self.p, self.k),
)
.run()
.unwrap_bool()
} else {
vm::Vm::load(
self.filter,
vm::Registers::load(0, self.counter, self.n, self.p, self.k),
)
.run()
.unwrap_bool()
};
self.counter += 1;
return Some(result);
}
let result = self.counter;
if self.cache.len() < CACHE_SIZE_LIMIT {
self.cache.push(result);
}
self.counter += 1;
Some(result)
}
}
@ -287,11 +183,11 @@ impl Iterator for FixedHammingWeight {
#[pymethods]
impl Prover {
/// Initializes a searcher for lower bounds.
///
///
/// This object prepares the search for lower bounds according to Ambainis et al.'s
/// adversary method. The bounds themselves for specific values of the parameters can
/// be found with the associated method `find_bounds`.
///
///
/// Arguments:
/// a_description A description of elements that should belong to the A set,
/// as specified discriminating on x [string]
@ -367,7 +263,7 @@ impl Prover {
}
/// Finds a lower bound to the query complexity according to the specified parameters.
///
///
/// Arguments:
/// n The maximum number of bits to consider when exhausting bitstrings
/// belonging to either of the specified sets [int]
@ -390,24 +286,15 @@ impl Prover {
));
}
let a_set_iterators = CachedSets::create_x(self.a_description.clone(), n, p, k)
.split_into(THREADS)
.into_iter()
.map(|iter| Mutex::new(iter))
.collect::<Vec<Mutex<SplitCachedSetIterator>>>();
let b_set_iterators = CachedSets::create_y(self.b_description.clone(), n, p, k)
.split_into(THREADS)
.into_iter()
.map(|iter| Mutex::new(iter))
.collect::<Vec<Mutex<SplitCachedSetIterator>>>();
let mut a_set_iterator = CachedSetIterator::create_x(&self.a_description, n, p, k);
let mut b_set_iterator = CachedSetIterator::create_y(&self.b_description, n, p, k);
let window_iterator = FixedHammingWeight::new(n, p);
let min_x_relations = Mutex::new(u64::MAX);
let min_y_relations = Mutex::new(u64::MAX);
let mut min_x_relations = u64::MAX;
let mut min_y_relations = u64::MAX;
let mut max_joint_relations = 0_u128;
let single_bounding_x = Mutex::new(0_u64);
let single_bounding_y = Mutex::new(0_u64);
let mut single_bounding_x = 0_u64;
let mut single_bounding_y = 0_u64;
let mut joint_bounding_x = 0_u64;
let mut joint_bounding_y = 0_u64;
let mut bounding_window = 0_u64;
@ -419,111 +306,60 @@ impl Prover {
};
for window in window_iterator {
let max_x_relations = Mutex::new(0_u64);
let max_y_relations = Mutex::new(0_u64);
let bounding_x_candidate = Mutex::new(0_u64);
let bounding_y_candidate = Mutex::new(0_u64);
let mut max_x_relations = 0_u64;
let mut max_y_relations = 0_u64;
let mut bounding_x_candidate = 0_u64;
let mut bounding_y_candidate = 0_u64;
std::thread::scope(|s| {
for thread_index in 0..THREADS {
s.spawn(closure!(
move thread_index,
ref a_set_iterators,
ref b_set_iterators,
ref min_x_relations,
ref single_bounding_x,
ref max_x_relations,
ref bounding_x_candidate, || {
let a_subset_iterator = &a_set_iterators[thread_index as usize];
let mut a_subset_iterator = a_subset_iterator.lock();
for x in &mut *a_subset_iterator {
// How many relations this x has
let mut min_x_relations_candidate = 0_u64;
// How many relations this x has with y such that (x & window) != (y & window)
let mut max_x_relations_candidate = 0_u64;
for x in &mut a_set_iterator {
let mut min_x_relations_candidate = 0_u64;
let mut max_x_relations_candidate = 0_u64;
for mut b_subset_iterator in AsAreFree::new(b_set_iterators) {
for y in &mut *b_subset_iterator {
if related(x, y) {
min_x_relations_candidate += 1;
if (x & window) != (y & window) {
max_x_relations_candidate += 1;
}
}
}
}
{
let min_x_relations = &mut *min_x_relations.lock();
if min_x_relations_candidate < *min_x_relations {
*min_x_relations = min_x_relations_candidate;
*(&mut *single_bounding_x.lock()) = x;
}
}
{
let max_x_relations = &mut *max_x_relations.lock();
if max_x_relations_candidate > *max_x_relations {
*max_x_relations = max_x_relations_candidate;
*(&mut *bounding_x_candidate.lock()) = x;
}
}
for y in &mut b_set_iterator {
if related(x, y) {
min_x_relations_candidate += 1;
if (x & window) != (y & window) {
max_x_relations_candidate += 1;
}
}));
s.spawn(closure!(
move thread_index,
ref a_set_iterators,
ref b_set_iterators,
ref min_y_relations,
ref single_bounding_y,
ref max_y_relations,
ref bounding_y_candidate, || {
let b_subset_iterator = &b_set_iterators[thread_index as usize];
let mut b_subset_iterator = b_subset_iterator.lock();
for y in &mut *b_subset_iterator {
// How many relations this y has
let mut min_y_relations_candidate = 0_u64;
// How many relations this y has with x such that (x & window) != (y & window)
let mut max_y_relations_candidate = 0_u64;
for mut a_subset_iterator in AsAreFree::new(a_set_iterators) {
for x in &mut *a_subset_iterator {
if related(x, y) {
min_y_relations_candidate += 1;
if (x & window) != (y & window) {
max_y_relations_candidate += 1;
}
}
}
}
{
let min_y_relations = &mut *min_y_relations.lock();
if min_y_relations_candidate < *min_y_relations {
*min_y_relations = min_y_relations_candidate;
*(&mut *single_bounding_y.lock()) = y;
}
}
{
let max_y_relations = &mut *max_y_relations.lock();
if max_y_relations_candidate > *max_y_relations {
*max_y_relations = max_y_relations_candidate;
*(&mut *bounding_y_candidate.lock()) = y;
}
}
}
}));
}
}
});
let max_x_relations = *max_x_relations.lock();
let max_y_relations = *max_y_relations.lock();
let bounding_x_candidate = *bounding_x_candidate.lock();
let bounding_y_candidate = *bounding_y_candidate.lock();
if min_x_relations_candidate < min_x_relations {
min_x_relations = min_x_relations_candidate;
single_bounding_x = x;
}
if max_x_relations_candidate > max_x_relations {
max_x_relations = max_x_relations_candidate;
bounding_x_candidate = x;
}
}
for y in &mut b_set_iterator {
let mut min_y_relations_candidate = 0_u64;
let mut max_y_relations_candidate = 0_u64;
for x in &mut a_set_iterator {
if related(x, y) {
min_y_relations_candidate += 1;
if (x & window) != (y & window) {
max_y_relations_candidate += 1;
}
}
}
if min_y_relations_candidate < min_y_relations {
min_y_relations = min_y_relations_candidate;
single_bounding_y = y;
}
if max_y_relations_candidate > max_y_relations {
max_y_relations = max_y_relations_candidate;
bounding_y_candidate = y;
}
}
let max_joint_relations_candidate = max_x_relations as u128 * max_y_relations as u128;
if max_joint_relations_candidate > max_joint_relations {
max_joint_relations = max_joint_relations_candidate;
@ -533,11 +369,6 @@ impl Prover {
}
}
let min_x_relations = *min_x_relations.lock();
let min_y_relations = *min_y_relations.lock();
let single_bounding_x = *single_bounding_x.lock();
let single_bounding_y = *single_bounding_y.lock();
Ok(BoundsResult {
min_x_relations,
min_y_relations,
@ -560,23 +391,14 @@ fn adversary(_py: Python, m: &PyModule) -> PyResult<()> {
Ok(())
}
fn choose(n: u64, k: u64) -> u64 {
if k == 0 {
1
} else {
(n * choose(n - 1, k - 1)) / k
}
}
#[test]
fn test_single_cached_iterator() {
fn test_cached_iterator() {
let n = 10;
let k = 3;
let p = 5;
let filter =
vm::compile_boolean(vm::parsing::parse_relation("= ham x k").expect("Valid expression"));
let mut sets = CachedSets::create_x(filter, n, p, k).split_into(1);
let mut iterator = sets.first_mut().unwrap();
vm::compile_boolean(vm::parsing::parse_relation("<= ham x k").expect("Valid expression"));
let mut iterator = CachedSetIterator::create_x(&filter, n, p, k);
let first_count = iterator.count();
let second_count = iterator.count();
for x in &mut iterator {
@ -584,25 +406,6 @@ fn test_single_cached_iterator() {
assert!(x.count_ones() <= k.into());
}
assert!(first_count == second_count);
assert!(first_count == choose(n.into(), k.into()) as usize);
}
#[test]
fn test_multiple_cached_iterator() {
let n = 20;
let k = 3;
let p = 5;
let filter =
vm::compile_boolean(vm::parsing::parse_relation("= ham x k").expect("Valid expression"));
let mut sets = CachedSets::create_x(filter, n, p, k).split_into(3);
let mut third_iter = sets.pop().unwrap();
let mut second_iter = sets.pop().unwrap();
let mut first_iter = sets.pop().unwrap();
drop(sets);
let first_total = first_iter.count() + second_iter.count() + third_iter.count();
let second_total = first_iter.count() + second_iter.count() + third_iter.count();
assert!(first_total == second_total);
assert!(first_total == choose(n.into(), k.into()) as usize);
}
#[test]

View File

@ -6,7 +6,7 @@ use crate::vm::parsing::ast::ComparisonOperator;
type Bytecode = (OpCode, usize);
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct VmCode {
code: Vec<Bytecode>,
}
@ -38,7 +38,7 @@ pub enum ArithmeticValue {
Floating(f64),
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum OpCode {
UnaryBooleanOperator(ast::UnaryBooleanOperator),
BinaryArithmeticOperator(ast::BinaryArithmeticOperator),

View File

@ -1,11 +1,11 @@
use std::fmt::Debug;
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum UnaryBooleanOperator {
Not,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum BinaryArithmeticOperator {
Times,
Divide,
@ -15,7 +15,7 @@ pub enum BinaryArithmeticOperator {
Pow,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum ComparisonOperator {
GreaterOrEqual,
LessOrEqual,
@ -25,21 +25,21 @@ pub enum ComparisonOperator {
Equal,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum UnaryArithmeticOperator {
Negative,
Ham,
Sqrt,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum BinaryBooleanOperator {
And,
Or,
Xor,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum Variable {
X,
Y,
@ -48,54 +48,54 @@ pub enum Variable {
K,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum BooleanExpression {
BinaryBooleanConjunction(Box<BinaryBooleanConjunction>),
UnaryBooleanConjunction(Box<UnaryBooleanConjunction>),
ComparisonConjunction(Box<ComparisonConjunction>),
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct BinaryBooleanConjunction {
pub operator: BinaryBooleanOperator,
pub left_operand: BooleanExpression,
pub right_operand: BooleanExpression,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct UnaryBooleanConjunction {
pub operator: UnaryBooleanOperator,
pub operand: BooleanExpression,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct ComparisonConjunction {
pub operator: ComparisonOperator,
pub left_operand: ArithmeticOperand,
pub right_operand: ArithmeticOperand,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct BinaryArithmeticConjunction {
pub operator: BinaryArithmeticOperator,
pub left_operand: ArithmeticOperand,
pub right_operand: ArithmeticOperand,
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum ArithmeticOperand {
Literal(i64),
Expression(Box<ArithmeticExpression>),
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum ArithmeticExpression {
Variable(Variable),
UnaryArithmeticConjunction(Box<UnaryArithmeticConjunction>),
BinaryArithmeticConjunction(Box<BinaryArithmeticConjunction>),
}
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct UnaryArithmeticConjunction {
pub operator: UnaryArithmeticOperator,
pub operand: ArithmeticOperand,