Question in short
To produce valid input for pycosat, is there a way to speed up the conversion from DNF to CNF, or to avoid it altogether?
Question in detail
I watched this video by Raymond Hettinger about modern solvers, downloaded the accompanying code, and used it to implement a solver for the puzzle game Towers. The code is below.
Example Tower puzzle (solved):
```text
    3 3 2 1
---------------
3 | 2 1 3 4 | 1
3 | 1 3 4 2 | 2
2 | 3 4 2 1 | 3
1 | 4 2 1 3 | 2
---------------
    1 2 3 2
```
The problem is that the conversion from DNF to CNF takes forever. Suppose you know that 3 towers are visible along a certain line of sight in a 5x5 grid. That leaves 35 possible permutations of 1-5 for that row:
```python
[('AA 1', 'AB 2', 'AC 5', 'AD 3', 'AE 4'),
 ('AA 1', 'AB 2', 'AC 5', 'AD 4', 'AE 3'),
 ...
 ('AA 3', 'AB 4', 'AC 5', 'AD 1', 'AE 2'),
 ('AA 3', 'AB 4', 'AC 5', 'AD 2', 'AE 1')]
```
This is in disjunctive normal form (DNF): an OR of several ANDs. pycosat needs conjunctive normal form (CNF): an AND of several ORs. The conversion is very slow, however. On my MacBook Pro it still had not finished computing the CNF for a single row after 5 minutes. For the entire puzzle it has to be done up to 20 times (4 sides × 5 lines for a 5x5 grid).
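A quick calculation shows why the conversion stalls. The sketch below uses only the standard library. Its `visible_from_line` is a simplified copy of the function in the code below, without the `reverse` option, and `naive_cnf` is a hypothetical helper. The script counts the DNF terms for a clue of 3 on a 5x5 grid. It also counts the clauses that a naive distribution of OR over AND produces: one clause for every way of picking one literal from each term.

```python
import itertools


def visible_from_line(line):
    """Count towers visible from the front of a line (copy of the question's helper)."""
    visible = highest_seen = 0
    for number in line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


def naive_cnf(groups):
    """Distribute OR over AND: each clause picks one literal from every group."""
    return list(itertools.product(*groups))


# Toy example: (a AND b) OR (c AND d)  ==  (a OR c)(a OR d)(b OR c)(b OR d)
print(naive_cnf([('a', 'b'), ('c', 'd')]))  # [('a', 'c'), ('a', 'd'), ('b', 'c'), ('b', 'd')]

# One row of a 5x5 grid with 3 towers visible
n, clue = 5, 3
terms = [perm for perm in itertools.permutations(range(1, n + 1))
         if visible_from_line(perm) == clue]
print(len(terms))        # 35 DNF terms
print(n ** len(terms))   # 2910383045673370361328125 clauses before simplification
```

The `from_dnf` helper removes duplicate and subsumed clauses as it goes. Its intermediate sets can still grow very large, however, because the raw expansion is around 2.9 × 10^24 clauses.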
What is the best way to optimize this code so that the computer can solve this Towers puzzle?
The code is also available in this GitHub repository.
```python
import string
import itertools
from sys import intern
from typing import Dict, List, Sequence

from sat_utils import basic_fact, from_dnf, one_of, solve_one

Point = str


def comb(point: Point, value: int) -> str:
    """
    Format a fact (a value assigned to a given point), and store it into the interned strings table
    :param point: Point on the grid, characterized by two letters, e.g. AB
    :param value: Value of the cell on that point, e.g. 2
    :return: Fact string 'AB 2'
    """
    return intern(f'{point} {value}')


def visible_from_line(line: Sequence[int], reverse: bool = False) -> int:
    """
    Return how many towers are visible from the given line
    >>> visible_from_line([1, 2, 3, 4])
    4
    >>> visible_from_line([1, 4, 3, 2])
    2
    """
    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


class TowersPuzzle:
    def __init__(self):
        self.visible_from_top = [3, 3, 2, 1]
        self.visible_from_bottom = [1, 2, 3, 2]
        self.visible_from_left = [3, 3, 2, 1]
        self.visible_from_right = [1, 2, 3, 2]
        self.given_numbers = {'AC': 3}

        # self.visible_from_top = [3, 2, 1, 4, 2]
        # self.visible_from_bottom = [2, 2, 4, 1, 2]
        # self.visible_from_left = [3, 2, 3, 1, 3]
        # self.visible_from_right = [2, 2, 1, 3, 2]

        self._cnf = None
        self._solution = None

    def display_puzzle(self):
        print('*** Puzzle ***')
        self._display(self.given_numbers)

    def display_solution(self):
        print('*** Solution ***')
        point_to_value = {point: value for point, value in [fact.split() for fact in self.solution]}
        self._display(point_to_value)

    @property
    def n(self) -> int:
        """
        :return: Size of the grid
        """
        return len(self.visible_from_top)

    @property
    def points(self) -> List[Point]:
        return [''.join(letters) for letters in itertools.product(string.ascii_uppercase[:self.n], repeat=2)]

    @property
    def rows(self) -> List[List[Point]]:
        """
        :return: Points, grouped per row
        """
        return [self.points[i:i + self.n] for i in range(0, self.n * self.n, self.n)]

    @property
    def cols(self) -> List[List[Point]]:
        """
        :return: Points, grouped per column
        """
        return [self.points[i::self.n] for i in range(self.n)]

    @property
    def values(self) -> List[int]:
        return list(range(1, self.n + 1))

    @property
    def cnf(self):
        if self._cnf is None:
            cnf = []

            # Each point assigned exactly one value
            for point in self.points:
                cnf += one_of(comb(point, value) for value in self.values)

            # Each value gets assigned to exactly one point in each row
            for row in self.rows:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in row)

            # Each value gets assigned to exactly one point in each col
            for col in self.cols:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in col)

            # Set visible from left
            if self.visible_from_left:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_left[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from right
            if self.visible_from_right:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_right[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from top
            if self.visible_from_top:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_top[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from bottom
            if self.visible_from_bottom:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_bottom[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set given numbers
            for point, value in self.given_numbers.items():
                cnf += basic_fact(comb(point, value))

            self._cnf = cnf

        return self._cnf

    @property
    def solution(self):
        if self._solution is None:
            self._solution = solve_one(self.cnf)
        return self._solution

    def _display(self, facts: Dict[Point, int]):
        # Pad with 4 spaces on each side so the clues line up with the "3 | " row prefix
        top_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_top]) + '    '
        print(top_line)
        print('-' * len(top_line))
        for index, row in enumerate(self.rows):
            # A missing clue (0) is shown as a blank
            elems = [str(self.visible_from_left[index] or ' '), '|'] + \
                    [str(facts.get(point, ' ')) for point in row] + \
                    ['|', str(self.visible_from_right[index] or ' ')]
            print(' '.join(elems))
        print('-' * len(top_line))
        bottom_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_bottom]) + '    '
        print(bottom_line)
        print()


if __name__ == '__main__':
    puzzle = TowersPuzzle()
    puzzle.display_puzzle()
    puzzle.display_solution()
```
Most of the time is spent in this helper function from the sat_utils module that came with the video:
```python
def from_dnf(groups) -> 'cnf':
    'Convert from or-of-ands to and-of-ors'
    cnf = {frozenset()}
    for group_index, group in enumerate(groups, start=1):
        print(f'Group {group_index}/{len(groups)}')
        nl = {frozenset([literal]): neg(literal) for literal in group}
        # The "clause | literal" prevents dup lits: {x, x, y} -> {x, y}
        # The nl check skips over identities: {x, ~x, y} -> True
        cnf = {clause | literal for literal in nl for clause in cnf
               if nl[literal] not in clause}
        # The sc check removes clauses with superfluous terms:
        #     {{x}, {x, z}, {y, z}} -> {{x}, {y, z}}
        # Should this be left until the end?
        sc = min(cnf, key=len)  # XXX not deterministic
        cnf -= {clause for clause in cnf if clause > sc}
    return list(map(tuple, cnf))
```
The pyinstrument output for a 4x4 grid shows that the `cnf = { ... }` line in this function is the culprit:
```text
Recorded: 21:05:58   Samples:  146
Duration: 0.515      CPU time: 0.506
pyinstrument v3.4.2

Program: ./src/towers.py

0.515 <module>  ../<string>:1
   [7 frames hidden]  .., runpy
      0.513 _run_code  runpy.py:62
      └─ 0.513 <module>  towers.py:1
         ├─ 0.501 display_solution  towers.py:64
         │  └─ 0.501 solution  towers.py:188
         │     ├─ 0.408 cnf  towers.py:101
         │     │  ├─ 0.397 from_dnf  sat_utils.py:65
         │     │  │  ├─ 0.329 <setcomp>  sat_utils.py:73
         │     │  │  ├─ 0.029 [self]
         │     │  │  ├─ 0.021 min  ../<built-in>:0
         │     │  │  │     [2 frames hidden]  ..
         │     │  │  └─ 0.016 <setcomp>  sat_utils.py:79
         │     │  └─ 0.009 [self]
         │     └─ 0.093 solve_one  sat_utils.py:53
         │        └─ 0.091 itersolve  sat_utils.py:43
         │           ├─ 0.064 translate  sat_utils.py:32
         │           │  ├─ 0.049 <listcomp>  sat_utils.py:39
         │           │  │  ├─ 0.028 [self]
         │           │  │  └─ 0.021 <listcomp>  sat_utils.py:39
         │           │  └─ 0.015 make_translate  sat_utils.py:12
         │           └─ 0.024 itersolve  ../<built-in>:0
         │                 [2 frames hidden]  ..
         └─ 0.009 <module>  typing.py:1
               [26 frames hidden]  typing, abc, ..
```
Answer
The slow conversion from DNF to CNF is a well-known problem. An equivalent CNF can be exponentially larger than the DNF it comes from, and 35 permutations per row is already enough to make the exact conversion impractical. Here are a few suggestions to speed it up or bypass it:
1. Avoid Full Conversion to CNF
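Instead of converting the DNF into an equivalent CNF, introduce one auxiliary variable per allowed permutation (a Tseitin-style encoding). Each auxiliary variable implies all facts of its permutation, and one extra clause requires at least one auxiliary variable to be true. The resulting CNF is not logically equivalent to the DNF, but it is equisatisfiable: every solution, restricted to the original facts, satisfies the DNF, and every assignment that satisfies the DNF extends to a solution. Its size grows linearly with the number of permutations (35 × 5 + 1 = 176 clauses for the row in the question) instead of exponentially. pycosat itself only accepts CNF, so the DNF cannot be passed to it directly.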
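A minimal sketch of the auxiliary-variable (Tseitin-style) encoding follows. Each permutation gets a fresh selector variable that implies its facts, and one clause requires at least one selector to be true. The sketch assumes the sat_utils conventions that facts are plain strings and that negation is a leading `'~'`, which is what its `neg()` helper appears to do. `from_dnf_aux` and the `'<tag>#<index>'` naming of the selectors are hypothetical. The tag must be unique per constraint (for example `'left0'` or `'top3'`) so that different rows and columns do not share selectors.

```python
from typing import List, Sequence, Tuple

Clause = Tuple[str, ...]


def neg(literal: str) -> str:
    """Negate a fact, assuming the sat_utils convention of a leading '~'."""
    return literal[1:] if literal.startswith('~') else '~' + literal


def from_dnf_aux(groups: Sequence[Sequence[str]], tag: str) -> List[Clause]:
    """Encode an OR of AND-groups as CNF with one fresh selector variable per group.

    Produces (total number of literals) + 1 clauses, i.e. linear in the input,
    instead of the exponential blow-up of a full DNF-to-CNF expansion.
    The result is equisatisfiable with the DNF, not logically equivalent.
    """
    cnf: List[Clause] = []
    selectors = []
    for index, group in enumerate(groups):
        selector = f'{tag}#{index}'  # hypothetical naming scheme for auxiliary variables
        selectors.append(selector)
        # selector -> literal, written as the clause (~selector OR literal)
        for literal in group:
            cnf.append((neg(selector), literal))
    # At least one group (permutation) must be selected
    cnf.append(tuple(selectors))
    return cnf


print(from_dnf_aux([('AA 1', 'AB 2'), ('AA 2', 'AB 1')], tag='left0'))
```

In `TowersPuzzle.cnf`, the line `cnf += from_dnf(possible_perms)` would then become something like `cnf += from_dnf_aux(possible_perms, tag=f'left{index}')`, with a different tag prefix for each direction. `solve_one` presumably also returns the selector variables that are true. If so, `display_solution` would have to skip them, for example by ignoring facts that contain `'#'` before calling `split()`.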
2. Use Optimized CNF Conversion Techniques
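The exact conversion distributes the OR over the ANDs, which yields one clause for every way of picking one literal from each term. For 35 permutations of 5 facts, that is 5^35 (about 2.9 × 10^24) candidate clauses before duplicates and subsumed clauses are removed. Optimizations can shave off constant factors, but they cannot remove this growth:
- Memoization: The CNF for a given clue and direction is the same for every row or column up to renaming the points. It can therefore be computed once and relabeled, which only pays off if a single conversion is affordable.
- Simplification: Before performing the conversion, remove redundant terms from the DNF, for example permutations that contradict the given numbers. Fewer terms mean a smaller CNF.
- Incremental Conversion: `from_dnf` already converts one group at a time, so splitting the work further does not change the worst case by itself.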
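The memoization idea can be sketched as follows, assuming the converter (for example `sat_utils.from_dnf`) only manipulates the literal strings and never interprets them. `clue_template` and `relabel` are hypothetical helpers. The DNF for a clue is built and converted once per `(n, clue, reverse, converter)` over placeholder facts of the form `'#<position> <value>'` and cached with `functools.lru_cache`. `relabel` then substitutes the real points of a row or column.

```python
import itertools
from functools import lru_cache
from typing import Callable, List, Sequence, Tuple

Clause = Tuple[str, ...]


def visible_from_line(line: Sequence[int], reverse: bool = False) -> int:
    """Count towers visible from the front (or the back, if reverse=True) of a line."""
    visible = highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


@lru_cache(maxsize=None)
def clue_template(n: int, clue: int, reverse: bool,
                  convert: Callable[[List[Clause]], Sequence[Sequence[str]]]) -> Tuple[Clause, ...]:
    """Convert the DNF for one clue once, over placeholder facts '#<position> <value>'."""
    groups = [tuple(f'#{position} {value}' for position, value in enumerate(perm))
              for perm in itertools.permutations(range(1, n + 1))
              if visible_from_line(perm, reverse) == clue]
    # Freeze the result so callers cannot mutate the cached value
    return tuple(tuple(clause) for clause in convert(groups))


def relabel(template: Sequence[Sequence[str]], points: Sequence[str]) -> List[Clause]:
    """Replace each placeholder with the real point of this line, keeping a leading '~'."""
    def rename(literal: str) -> str:
        sign = '~' if literal.startswith('~') else ''
        position, value = literal.lstrip('~')[1:].split()
        return f'{sign}{points[int(position)]} {value}'
    return [tuple(rename(literal) for literal in clause) for clause in template]
```

In the left-clue loop of `TowersPuzzle.cnf`, this could be used as `cnf += relabel(clue_template(self.n, target_visible, False, from_dnf), row)`. It still needs one full conversion per distinct clue, so it does not help if even that single conversion is too slow.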
3. Parallelize Permutation Handling
Each row and column is converted independently, so the conversions can run in parallel:
- Multiprocessing: Use Python's `multiprocessing` module (or `concurrent.futures.ProcessPoolExecutor`) to handle different rows/columns on separate cores.
- Joblib: `joblib` offers a simple way to parallelize loops whose iterations (one CNF per row or column) are independent.

This divides the wall-clock time by at most the number of cores, so it cannot make up for a conversion that grows exponentially with the number of permutations.
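A minimal sketch of farming out independent per-line work with `concurrent.futures.ProcessPoolExecutor` from the standard library follows. The worker here only builds the allowed permutations for one clue. In the real code, the expensive `from_dnf` call would go into the worker as well. `allowed_perms` and `allowed_perms_parallel` are hypothetical names, and the clue list is the commented-out 5x5 left-hand clues from the question.

```python
import itertools
from concurrent.futures import ProcessPoolExecutor
from typing import Iterable, List, Optional, Sequence, Tuple

Task = Tuple[int, int, bool]  # (grid size n, clue, reverse)


def visible_from_line(line: Sequence[int], reverse: bool = False) -> int:
    """Count towers visible from the front (or the back, if reverse=True) of a line."""
    visible = highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


def allowed_perms(task: Task) -> List[Tuple[int, ...]]:
    """Worker: all permutations of 1..n that match one clue."""
    n, clue, reverse = task
    return [perm for perm in itertools.permutations(range(1, n + 1))
            if visible_from_line(perm, reverse) == clue]


def allowed_perms_parallel(tasks: Iterable[Task],
                           max_workers: Optional[int] = None) -> List[List[Tuple[int, ...]]]:
    """Run one worker call per clue; results come back in the same order as tasks."""
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(allowed_perms, tasks))


if __name__ == '__main__':
    # Left-hand clues of the commented-out 5x5 puzzle in the question
    tasks = [(5, clue, False) for clue in (3, 2, 3, 1, 3)]
    for task, perms in zip(tasks, allowed_perms_parallel(tasks)):
        print(task, len(perms))
```

The `__main__` guard matters. On platforms that start worker processes with "spawn" (Windows, and macOS by default), each worker re-imports the module, and an unguarded pool would try to start new processes recursively.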
4. Optimize Permutation Handling
Generating all permutations and then filtering them is cheap for a 5x5 grid (120 permutations per line), but the cost grows as n! for larger grids. Note that this speeds up building the DNF, not the conversion itself:
- Partial Evaluation: Instead of generating all permutations and then filtering them by visibility, enforce the constraints while generating the permutations, so that invalid ones are discarded early.
- Backtracking/Constraint Propagation: Apply constraint satisfaction techniques such as backtracking during permutation generation to cut down the number of permutations that have to be examined.
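A sketch of the backtracking idea using only the standard library follows. `perms_with_visibility` is a hypothetical name. It builds permutations left to right and abandons a prefix as soon as too many towers are already visible, or when the remaining taller towers can no longer reach the clue. It yields the same permutations as the generate-and-filter loop in the question, in lexicographic order.

```python
from typing import FrozenSet, Iterator, List, Tuple


def perms_with_visibility(n: int, clue: int) -> Iterator[Tuple[int, ...]]:
    """Yield, in lexicographic order, the permutations of 1..n with exactly `clue`
    towers visible from the front, pruning prefixes that cannot succeed."""

    def extend(prefix: List[int], remaining: FrozenSet[int],
               visible: int, highest: int) -> Iterator[Tuple[int, ...]]:
        # Prune: too many towers visible already, or not enough taller towers
        # left to reach the clue (each remaining taller tower adds at most one).
        taller_left = sum(1 for r in remaining if r > highest)
        if visible > clue or visible + taller_left < clue:
            return
        if not remaining:
            # Both checks passed with nothing left, so visible == clue here
            yield tuple(prefix)
            return
        for value in sorted(remaining):
            prefix.append(value)
            yield from extend(prefix, remaining - {value},
                              visible + (value > highest), max(highest, value))
            prefix.pop()

    yield from extend([], frozenset(range(1, n + 1)), 0, 0)


print(sum(1 for _ in perms_with_visibility(5, 3)))  # 35
```

For right or bottom clues, reverse each result (`perm[::-1]`). The towers visible from the far end of a line are exactly the towers visible from the front of the reversed line.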
5. Alternative SAT Solvers or Methods
You may want to experiment with different SAT solvers or libraries. `pycosat` is a solid solver, and the profile in the question shows that the solving step (`solve_one`) is not the bottleneck; building the CNF is.
- Z3 Solver: Z3 is an SMT solver whose Python API accepts arbitrary Boolean formulas, such as an `Or` of `And` terms. The DNF can therefore be stated directly, and Z3 handles the clause conversion internally.
- MiniSat: Another fast SAT solver that you could try in place of `pycosat`. Like `pycosat`, however, it expects CNF input, so switching to it does not remove the conversion step on its own.
6. Preprocessing Step to Reduce Permutations
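Before building the DNF, use the constraints you already know to eliminate permutations. For example, a given number fixes the value at one position of its row and column, so every permutation that disagrees with it can be dropped. In the 4x4 example, the left clue 3 for the first row combined with the given AC = 3 leaves only one permutation, 2 1 3 4. Fewer terms make the DNF, and therefore the CNF, much smaller, especially for larger grids.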
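A minimal sketch of this filtering step using only the standard library follows. `allowed_perms` is a hypothetical helper. The point names and the givens dictionary follow the question's conventions: points such as `'AC'` and givens such as `{'AC': 3}`.

```python
import itertools
from typing import Dict, List, Sequence, Tuple


def visible_from_line(line: Sequence[int], reverse: bool = False) -> int:
    """Count towers visible from the front (or the back, if reverse=True) of a line."""
    visible = highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


def allowed_perms(points: Sequence[str], clue: int, given: Dict[str, int],
                  reverse: bool = False) -> List[Tuple[int, ...]]:
    """Permutations for one line that match its clue and agree with the given numbers."""
    # Positions in this line whose value is already fixed by the puzzle
    fixed = {i: given[point] for i, point in enumerate(points) if point in given}
    return [perm for perm in itertools.permutations(range(1, len(points) + 1))
            # Cheap check first: discard permutations that contradict a given number
            if all(perm[i] == value for i, value in fixed.items())
            and visible_from_line(perm, reverse) == clue]


print(allowed_perms(['AA', 'AB', 'AC', 'AD'], 3, {'AC': 3}))  # [(2, 1, 3, 4)]
```

Without the given number, the same clue allows 6 permutations for that row. With it, the DNF shrinks to a single term, which the conversion handles instantly.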
By applying one or more of these techniques, you should be able to bring the running time down considerably. Let me know if you need more details or help with specific optimizations!