Slow DNF to CNF conversion for pycosat

ghz, 4 days ago

Question in short

To produce valid input for pycosat, is there a way to speed up the conversion from DNF to CNF, or to avoid it altogether?

Question in detail

I have been watching this video by Raymond Hettinger about modern solvers. I downloaded the accompanying code and used it to implement a solver for the puzzle game Towers. The code is below.

Example Towers puzzle (solved):

    3 3 2 1    
---------------
3 | 2 1 3 4 | 1
3 | 1 3 4 2 | 2
2 | 3 4 2 1 | 3
1 | 4 2 1 3 | 2
---------------
    1 2 3 2    

The problem is that the conversion from DNF to CNF takes forever. Say you know that 3 towers are visible along a certain line of sight. On a 5x5 grid, this leaves 35 possible permutations of 1-5 for that row:

[('AA 1', 'AB 2', 'AC 5', 'AD 3', 'AE 4'),
 ('AA 1', 'AB 2', 'AC 5', 'AD 4', 'AE 3'),
 ...
 ('AA 3', 'AB 4', 'AC 5', 'AD 1', 'AE 2'),
 ('AA 3', 'AB 4', 'AC 5', 'AD 2', 'AE 1')]

This is a disjunctive normal form (DNF): an OR of several ANDs. It has to be converted into conjunctive normal form (CNF), an AND of several ORs, because that is the input format pycosat expects. However, the conversion is very slow: on my MacBook Pro, it had not finished computing the CNF for a single row after 5 minutes. For the entire puzzle, it has to be done up to 20 times on a 5x5 grid (one clue per row or column on each of the four sides).
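The count of 35 and the reason for the slowdown can be checked with a short sketch. The first part reuses the logic of the question's visible_from_line to tally all 120 orderings of 1 to 5 (the tallies are the unsigned Stirling numbers of the first kind). The second part applies the distributive law naively, which is essentially what from_dnf does before its deduplication and subsumption pruning; distribute is a hypothetical helper for illustration only. The pruning in from_dnf keeps the real clause sets smaller than this bound, but, consistent with the timing above, they still grow far faster than the number of terms.

```python
from collections import Counter
from itertools import permutations, product


def visible_from_line(line, reverse=False):
    """Same logic as the question's helper: towers visible along `line`."""
    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


# How many orderings of heights 1..5 show exactly k towers from one side?
counts = Counter(visible_from_line(perm) for perm in permutations(range(1, 6)))
print(sorted(counts.items()))  # [(1, 24), (2, 50), (3, 35), (4, 10), (5, 1)]


def distribute(terms):
    """Naive DNF -> CNF by the distributive law (hypothetical helper).

    Each clause picks one literal from every term, so k terms of
    m literals each give m**k clauses before any simplification.
    """
    return [frozenset(choice) for choice in product(*terms)]


# (a AND b) OR (c AND d)  ==  (a OR c) AND (a OR d) AND (b OR c) AND (b OR d)
print(len(distribute([('a', 'b'), ('c', 'd')])))  # 4

# One row with clue 3 on a 5x5 grid: 35 terms of 5 literals each.
print(f'{5 ** 35:.2e} clauses before simplification')  # 2.91e+24 clauses ...
```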

What is the best way to optimize this code so that it can solve a Towers puzzle in reasonable time?

This code is also available in this GitHub repository.

import itertools
import string
from sys import intern
from typing import Collection, Dict, List

from sat_utils import basic_fact, from_dnf, one_of, solve_one

Point = str


def comb(point: Point, value: int) -> str:
    """
    Format a fact (a value assigned to a given point), and store it into the interned strings table

    :param point: Point on the grid, characterized by two letters, e.g. AB
    :param value: Value of the cell on that point, e.g. 2
    :return: Fact string 'AB 2'
    """

    return intern(f'{point} {value}')


def visible_from_line(line: Collection[int], reverse: bool = False) -> int:
    """
    Return how many towers are visible from the given line

    >>> visible_from_line([1, 2, 3, 4])
    4
    >>> visible_from_line([1, 4, 3, 2])
    2
    """

    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


class TowersPuzzle:
    def __init__(self):
        self.visible_from_top = [3, 3, 2, 1]
        self.visible_from_bottom = [1, 2, 3, 2]
        self.visible_from_left = [3, 3, 2, 1]
        self.visible_from_right = [1, 2, 3, 2]
        self.given_numbers = {'AC': 3}

        # self.visible_from_top = [3, 2, 1, 4, 2]
        # self.visible_from_bottom = [2, 2, 4, 1, 2]
        # self.visible_from_left = [3, 2, 3, 1, 3]
        # self.visible_from_right = [2, 2, 1, 3, 2]

        self._cnf = None
        self._solution = None

    def display_puzzle(self):
        print('*** Puzzle ***')
        self._display(self.given_numbers)

    def display_solution(self):
        print('*** Solution ***')
        point_to_value = {point: value for point, value in [fact.split() for fact in self.solution]}
        self._display(point_to_value)

    @property
    def n(self) -> int:
        """
        :return: Size of the grid
        """

        return len(self.visible_from_top)

    @property
    def points(self) -> List[Point]:
        return [''.join(letters) for letters in itertools.product(string.ascii_uppercase[:self.n], repeat=2)]

    @property
    def rows(self) -> List[List[Point]]:
        """
        :return: Points, grouped per row
        """

        return [self.points[i:i + self.n] for i in range(0, self.n * self.n, self.n)]

    @property
    def cols(self) -> List[List[Point]]:
        """
        :return: Points, grouped per column
        """

        return [self.points[i::self.n] for i in range(self.n)]

    @property
    def values(self) -> List[int]:
        return list(range(1, self.n + 1))

    @property
    def cnf(self):
        if self._cnf is None:
            cnf = []

            # Each point assigned exactly one value
            for point in self.points:
                cnf += one_of(comb(point, value) for value in self.values)

            # Each value gets assigned to exactly one point in each row
            for row in self.rows:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in row)

            # Each value gets assigned to exactly one point in each col
            for col in self.cols:
                for value in self.values:
                    cnf += one_of(comb(point, value) for point in col)

            # Set visible from left
            if self.visible_from_left:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_left[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from right
            if self.visible_from_right:
                for index, row in enumerate(self.rows):
                    target_visible = self.visible_from_right[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(row, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from top
            if self.visible_from_top:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_top[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set visible from bottom
            if self.visible_from_bottom:
                for index, col in enumerate(self.cols):
                    target_visible = self.visible_from_bottom[index]
                    if not target_visible:
                        continue
                    possible_perms = []
                    for perm in itertools.permutations(range(1, self.n + 1)):
                        if visible_from_line(perm, reverse=True) == target_visible:
                            possible_perms.append(tuple(
                                comb(point, value)
                                for point, value in zip(col, perm)
                            ))
                    cnf += from_dnf(possible_perms)

            # Set given numbers
            for point, value in self.given_numbers.items():
                cnf += basic_fact(comb(point, value))

            self._cnf = cnf

        return self._cnf

    @property
    def solution(self):
        if self._solution is None:
            self._solution = solve_one(self.cnf)
        return self._solution

    def _display(self, facts: Dict[Point, int]):
        top_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_top]) + '    '
        print(top_line)
        print('-' * len(top_line))
        for index, row in enumerate(self.rows):
            # Show a blank instead of 0 (or None) for a missing clue
            elems = [str(self.visible_from_left[index] or ' '), '|'] + \
                    [str(facts.get(point, ' ')) for point in row] + \
                    ['|', str(self.visible_from_right[index] or ' ')]
            print(' '.join(elems))
        print('-' * len(top_line))
        bottom_line = '    ' + ' '.join([str(elem) if elem else ' ' for elem in self.visible_from_bottom]) + '    '
        print(bottom_line)
        print()


if __name__ == '__main__':
    puzzle = TowersPuzzle()
    puzzle.display_puzzle()
    puzzle.display_solution()

Most of the time is spent in this helper function from the sat_utils module that accompanies the video:

def from_dnf(groups) -> 'cnf':
    'Convert from or-of-ands to and-of-ors'
    # neg() is defined in the same sat_utils module and negates a literal (x <-> ~x)
    cnf = {frozenset()}
    for group_index, group in enumerate(groups, start=1):
        print(f'Group {group_index}/{len(groups)}')
        nl = {frozenset([literal]): neg(literal) for literal in group}
        # The "clause | literal" prevents dup lits: {x, x, y} -> {x, y}
        # The nl check skips over identities: {x, ~x, y} -> True
        cnf = {clause | literal for literal in nl for clause in cnf
               if nl[literal] not in clause}
        # The sc check removes clauses with superfluous terms:
        #     {{x}, {x, z}, {y, z}} -> {{x}, {y, z}}
        # Should this be left until the end?
        sc = min(cnf, key=len)  # XXX not deterministic
        cnf -= {clause for clause in cnf if clause > sc}
    return list(map(tuple, cnf))

The pyinstrument output for the 4x4 grid shows that the set comprehension cnf = {...} in from_dnf (sat_utils.py:73 in the profile) is the culprit:

  _     ._   __/__   _ _  _  _ _/_   Recorded: 21:05:58  Samples:  146
 /_//_/// /_\ / //_// / //_'/ //     Duration: 0.515     CPU time: 0.506
/   _/                      v3.4.2

Program: ./src/towers.py

0.515 <module>  ../<string>:1
   [7 frames hidden]  .., runpy
      0.513 _run_code  runpy.py:62
      └─ 0.513 <module>  towers.py:1
         ├─ 0.501 display_solution  towers.py:64
         │  └─ 0.501 solution  towers.py:188
         │     ├─ 0.408 cnf  towers.py:101
         │     │  ├─ 0.397 from_dnf  sat_utils.py:65
         │     │  │  ├─ 0.329 <setcomp>  sat_utils.py:73
         │     │  │  ├─ 0.029 [self]
         │     │  │  ├─ 0.021 min  ../<built-in>:0
         │     │  │  │     [2 frames hidden]  ..
         │     │  │  └─ 0.016 <setcomp>  sat_utils.py:79
         │     │  └─ 0.009 [self]
         │     └─ 0.093 solve_one  sat_utils.py:53
         │        └─ 0.091 itersolve  sat_utils.py:43
         │           ├─ 0.064 translate  sat_utils.py:32
         │           │  ├─ 0.049 <listcomp>  sat_utils.py:39
         │           │  │  ├─ 0.028 [self]
         │           │  │  └─ 0.021 <listcomp>  sat_utils.py:39
         │           │  └─ 0.015 make_translate  sat_utils.py:12
         │           └─ 0.024 itersolve  ../<built-in>:0
         │                 [2 frames hidden]  ..
         └─ 0.009 <module>  typing.py:1
               [26 frames hidden]  typing, abc, ..

Answer

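The slow conversion is expected: converting DNF to CNF with the distributive law is exponential in the worst case, because a DNF of k terms with m literals each can produce up to m^k clauses before simplification (5^35 for one row of the 5x5 puzzle). Here are a few ways to speed it up or avoid it: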

1. Avoid Full Conversion to CNF

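Instead of fully converting DNF to CNF, keep the valid permutations as they are and encode the choice between them directly. Solvers that accept arbitrary Boolean formulas (see point 5) do this for you. pycosat only accepts CNF, but that CNF does not need to be logically equivalent to the DNF, only equisatisfiable: adding one auxiliary variable per permutation gives a Tseitin-style encoding whose size is linear in the size of the DNF instead of exponential.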
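A minimal sketch of that auxiliary-variable encoding, assuming the symbolic literal convention of sat_utils (strings, with a leading '~' for negation). The function name dnf_with_selectors and the '<prefix>#<i>' naming of the selector variables are hypothetical. For one row with 35 permutations of 5 cells it produces 1 + 35 * 5 = 176 two-literal or short clauses, instead of the exponential distribution.

```python
def neg(literal: str) -> str:
    """Negate a symbolic literal using the '~' prefix convention of sat_utils."""
    return literal[1:] if literal.startswith('~') else '~' + literal


def dnf_with_selectors(groups, prefix):
    """Encode an OR of ANDs as CNF with one fresh selector variable per term.

    Produces 1 + (total number of literals) clauses, i.e. linear size.
    The result is equisatisfiable with the DNF rather than equivalent,
    because it mentions the extra selector variables.
    """
    selectors = [f'{prefix}#{i}' for i in range(len(groups))]
    # At least one term must be chosen.
    cnf = [tuple(selectors)]
    # Choosing a term forces each of its literals: (sel -> lit) == (~sel OR lit).
    for selector, group in zip(selectors, groups):
        cnf.extend((neg(selector), literal) for literal in group)
    return cnf


# Tiny example: (x AND y) OR (~x AND z)
clauses = dnf_with_selectors([('x', 'y'), ('~x', 'z')], prefix='row0')
for clause in clauses:
    print(clause)
```

In the question's cnf property, each cnf += from_dnf(possible_perms) could become cnf += dnf_with_selectors(possible_perms, prefix=...), with a prefix that is unique per line and direction (for example f'left{index}'). Assuming solve_one returns every true variable, display_solution would then also need to skip the selector names before splitting facts into point and value.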

2. Use Optimized CNF Conversion Techniques

The current method of converting from DNF to CNF could be improved:

  • Memoization: The clauses for a clue depend only on the clue value and the direction, not on which row or column it applies to, so compute them once per distinct clue and rename the cells for each line.
  • Simplification: Before performing the conversion, simplify the DNF by removing redundant literals or terms. A smaller DNF gives a smaller CNF.
  • Incremental Conversion: from_dnf already processes one term at a time, so this only pays off if each intermediate result is simplified aggressively; on its own it does not remove the exponential worst case.
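A different way to obtain a compact CNF directly, without any conversion, relies on the exactly-one constraints the question already adds. Because they force every row and column to be a permutation of 1..n, a visibility clue is equivalent to forbidding each permutation that violates it, and each forbidden permutation becomes a single clause of negated literals. The sketch below assumes the question's 'AA 1' fact format and the '~' negation convention of sat_utils; forbid_invalid_perms is a hypothetical name. For a 5-cell row with clue 3 it produces 120 - 35 = 85 clauses; since the count is bounded by n!, this suits small grids best.

```python
from itertools import permutations


def neg(literal: str) -> str:
    """Negate a symbolic literal using the '~' prefix convention of sat_utils."""
    return literal[1:] if literal.startswith('~') else '~' + literal


def visible_from_line(line, reverse=False):
    """Same logic as the question's helper: towers visible along `line`."""
    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


def forbid_invalid_perms(points, target_visible, reverse=False):
    """CNF for one visibility clue: one blocking clause per violating permutation.

    Only correct together with the one_of constraints that already force
    the line to be a permutation of 1..n.
    """
    clauses = []
    for perm in permutations(range(1, len(points) + 1)):
        if visible_from_line(perm, reverse) != target_visible:
            # NOT(a AND b AND ...) is the clause (~a OR ~b OR ...)
            clauses.append(tuple(neg(f'{point} {value}')
                                 for point, value in zip(points, perm)))
    return clauses


row = ['AA', 'AB', 'AC', 'AD', 'AE']
clauses = forbid_invalid_perms(row, 3)
print(len(clauses))  # 85 (= 120 - 35)
```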

3. Parallelize Permutation Handling

Each row and column constraint is converted independently, so the conversions can run in parallel:

  • Multiprocessing: Use Python's multiprocessing module to convert different rows/columns in separate processes. This spreads the work over multiple cores, but it divides the running time by at most the number of cores and does not change the exponential growth of a single conversion.
  • Joblib: joblib offers a simpler interface for running independent loop iterations (here, one CNF per line) in parallel.

4. Optimize Permutation Handling

Generating all n! permutations of a line is cheap for small grids (120 for a 5x5 grid), but the count grows quickly with n, and every surviving permutation becomes another term in the DNF. You might be able to reduce the search space:

  • Partial Evaluation: Instead of generating all permutations and then filtering them by visibility, try to enforce constraints as you generate the permutations. This way, invalid permutations are discarded earlier in the process.
  • Backtracking/Constraint Propagation: You could apply constraint satisfaction techniques like backtracking during the permutation generation to cut down the number of permutations needed.
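As a sketch of enforcing the clue during generation, the backtracking generator below extends a prefix one height at a time and abandons it as soon as the target can no longer be met: either too many towers are already visible, or too few taller heights remain to reach the target. perms_with_visibility is a hypothetical helper; it should yield the same permutations as the question's generate-then-filter loop, which the last lines check.

```python
from itertools import permutations


def visible_from_line(line, reverse=False):
    """Same logic as the question's helper: towers visible along `line`."""
    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


def perms_with_visibility(n: int, target: int):
    """Yield permutations of 1..n showing exactly `target` towers from the front."""
    def extend(prefix, remaining, visible, highest):
        if not remaining:
            if visible == target:
                yield tuple(prefix)
            return
        # Each remaining height taller than `highest` can add at most one
        # visible tower; if n is not placed yet, it will certainly add one.
        taller = sum(1 for v in remaining if v > highest)
        must_add = 1 if highest < n else 0
        if visible + must_add > target or visible + taller < target:
            return  # prune: target unreachable from this prefix
        for value in sorted(remaining):
            prefix.append(value)
            remaining.remove(value)
            yield from extend(prefix, remaining, visible + (value > highest),
                              max(highest, value))
            remaining.add(value)
            prefix.pop()

    yield from extend([], set(range(1, n + 1)), 0, 0)


# Sanity check against brute-force filtering.
pruned = list(perms_with_visibility(5, 3))
brute = [p for p in permutations(range(1, 6)) if visible_from_line(p) == 3]
print(len(pruned), pruned == brute)  # 35 True
```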

5. Alternative SAT Solvers or Methods

You may want to experiment with other solvers or libraries. Keep in mind, though, that the profile in the question shows solve_one taking only about 0.09 s of the 0.5 s total: the time goes into building the CNF in Python, not into solving it. A different solver therefore helps mainly if it removes the need for that conversion.

  • Z3 Solver: Z3 is an SMT solver that accepts arbitrary Boolean formulas, so each line constraint can be stated directly as an Or of Ands, leaving the Boolean structure for Z3 to handle internally.
  • MiniSat: Another fast SAT solver, but like pycosat it expects CNF input, so by itself it does not remove the conversion step.

6. Preprocessing Step to Reduce Permutations

Before building the DNF, apply every constraint that is known in advance, such as the clues at both ends of the same line and any given numbers, so that fewer permutations survive. This can significantly reduce the number of terms, especially for larger grids.
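A sketch of that preprocessing for one line, assuming clues of 0 (or None) mean "no clue" and given numbers are passed as a hypothetical {position: value} mapping rather than the question's point names. On the first row of the 4x4 example (3 visible from the left, 1 from the right, AC = 3) only one permutation survives, so that row's DNF shrinks to a single term before from_dnf or any other encoding sees it.

```python
from itertools import permutations


def visible_from_line(line, reverse=False):
    """Same logic as the question's helper: towers visible along `line`."""
    visible = 0
    highest_seen = 0
    for number in reversed(line) if reverse else line:
        if number > highest_seen:
            visible += 1
            highest_seen = number
    return visible


def candidate_rows(n, left, right, given=None):
    """Permutations of 1..n consistent with both clues and any fixed cells.

    `given` maps a 0-based position in the line to its known value
    (hypothetical representation). A clue of 0 or None means "no clue".
    """
    given = given or {}
    result = []
    for perm in permutations(range(1, n + 1)):
        if left and visible_from_line(perm) != left:
            continue
        if right and visible_from_line(perm, reverse=True) != right:
            continue
        if any(perm[pos] != value for pos, value in given.items()):
            continue
        result.append(perm)
    return result


# First row of the 4x4 example: 3 visible from the left, 1 from the right.
print(candidate_rows(4, 3, 1))          # [(1, 3, 2, 4), (2, 1, 3, 4), (2, 3, 1, 4)]
# Adding the given number AC = 3 (position 2 in row A).
print(candidate_rows(4, 3, 1, {2: 3}))  # [(2, 1, 3, 4)]
```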

Implementing one or more of these techniques, in particular one that avoids the exponential DNF-to-CNF distribution, should reduce the running time substantially. Let me know if you need more details or help with specific optimizations!