Résolution de puzzle

Où l'on utilise un solveur pour résoudre un problème…

Nous allons nous intéresser à la résolution d'un puzzle proposé sur le site de Jane Street. Cette société de fintech a deux caractéristiques intéressantes : tout d'abord, il y a une quinzaine d'années, elle a changé tout son infrasctructure informatique pour tout redévelopper en Ocaml. Depuis, elle est devenu un acteur majeur du développement de ce langage. Une autre caractéristique, celle qui nous intéresse aujourd'hui, est qu'elle propose tous les mois des puzzles souvent intéressants. Celui du mois de juin 2018 est le suivant :

La grille suivante est incomplète. Placez des nombres dans certaines cases vides de telle sorte que la grille contient au total un nombre 1, deux nombres 2, …, neuf nombres 9. De plus, chaque ligne et chaque colonne doit contenir exactement quatres nombres dont la somme est égale à 20. Enfin, les cases numérotées doivent une région connexe et chaque sous-carré de taille 2 x 2 doit contenir au moins une case vide.

La réponse à trouver est le produit des aires des groupes connexes de cases vides de la grille complète.

4
6
3
6
5
5
4
4
7
2
7
4
1

Nous allons résoudre ce problème en deux étapes. Dans un premier temps, à l'aide d'un solveur de satisfiabilité modulo des théories, nous allons traduire une partie des contraintes sous forme logique et utiliser le solveur pour déterminer les solutions compatibles avec ces contraintes. Dans un second temps, nous utiliserons l'algorithme Union-Find pour traîter les problèmes de connexité.

Présentation de Z3

Notre outil principal va être le solveur Z3 développé par Microsoft, qui permet de vérifier si des contraintes logiques et arithmétiques peuvent être satisfaites et donner, le cas échéant, des valeurs aux variables qui satisfont les contraintes.

Pour l'utiliser en Python, il vous faudra probablement l'installer (avec pip ou de façon plus manuelle comme expliqué sur la page de présentation du solveur).

Voici un exemple très simple d'utilisation. Commençons par importer le module et par déclarer quelques variables entières.

>>> from z3 import *
>>> x = Int('x')
>>> y = Int('y')
>>> z = Int('z')

Nous allons ensuite définir un solveur auquel nous allons ajouter un certain nombre de contraintes :

>>> s = Solver()
>>> s.add(x*x + y*y == z*z)
>>> s.add(z <= 10)
>>> s.add(z >= 0)
>>> s.add(x >= 1)
>>> s.add(y >= 1)

Ici, nous avons donc comme contraintes que

x2+y2=z2,0z10,1xet1y&#123;`x^2 + y^2 = z^2, \\qquad 0 \\leq z \\leq 10, \\qquad 1 \\leq x \\qquad \\mathrm{et&#125; \\qquad 1 \\leq y`}

Est-il possible de satisfaire ces contraintes ?

>>> s.check()
sat

La réponse sat indique que c'est le cas, s a réussi à déterminer des valeurs qui conviennent. Pour récupérer ces valeurs, on procède ainsi :

>>> s.model()
[z = 10, y = 8, x = 6]

Le modèle proposé peut s'utiliser comme un dictionnaire :

>>> m = s.model()
>>> m[x]
6

Par contre, ici, m[x] n'est pas vraiment un entier :

>>> type(m[x])
<class 'z3.z3.IntNumRef'>

et on ne peut pas directement tester sa valeur :

>>> m[x] == 2
6 == 2
>>> type(m[x] == 2)
<class 'z3.z3.BoolRef'>

Pour cela, il faut utiliser as_long() pour le convertir en véritable entier :

>>> m[x].as_long()
6
>>> type(m[x].as_long())
<class 'int'>
>>> m[x].as_long() == 2
False

Pour plus de renseignements, je vous conseille l'excellente page d'introduction à Z3 en Python d'Eric Pony.

Modélisation du problème

Nous allons tout d'abord créer une variable pour chaque case de la grille :

C = [[Int("c{}{}".format(l, c)) for c in range(7)] for l in range(7)]

Avec ces variables, nous allons coder les différentes contraintes du puzzle à l'aide d'une fonction probleme prenant C comme argument :

def probleme(C):
    contraintes = []
    ...
  1. Chaque case non vide contient un nombre compris entre 1 et 9. Il est très pratique de supposer de plus qu'une case vide contient le nombre 0.
    ...
    # encadrement des valeurs des cellules
    for l in range(7):
        for c in range(7):
            contraintes.append(And(C[l][c] >= 0, C[l][c] <= 9))
    ...
  1. La somme de chaque ligne et de chaque colonne est égale à 20. On utilise la fonction Sum de z3 pour simplifier l'écriture de la contrainte.
    ...
    # sommes des lignes
    for l in range(7):
        contraintes.append(Sum([C[l][c] for c in range(7)]) == 20)
    # sommes des colonnes
    for c in range(7):
        contraintes.append(Sum([C[l][c] for l in range(7)]) == 20)
    ...
  1. Chaque carré 2 x 2 contient au moins une case vide :
    ...
    # les carrés 2x2 ne sont pas entièrement pleins
    for l in range(6):
        for c in range(6):
            contraintes.append(Or(
                C[l][c] == 0,
                C[l + 1][c] == 0,
                C[l][c + 1] == 0,
                C[l + 1][c + 1] == 0
            ))
    ...
  1. Chaque ligne et chaque colonne contient exactement 4 valeurs non nulles, et donc 3 valeurs nulles. Pour traîter cette contrainte, nous avons besoin d'énumérer les ensembles de 3 entiers compris entre 0 et 6. Pour cela, nous allons utiliser la bibliothèque itertools qui permet de faire cela très simplement :
>>> import itertools
>>> pos = itertools.combinations(range(7), 3)
>>> list(pos)
[(0, 1, 2), (0, 1, 3), (0, 1, 4), (0, 1, 5), (0, 1, 6), (0, 2, 3),
 (0, 2, 4), (0, 2, 5), (0, 2, 6), (0, 3, 4), (0, 3, 5), (0, 3, 6),
 (0, 4, 5), (0, 4, 6), (0, 5, 6), (1, 2, 3), (1, 2, 4), (1, 2, 5),
 (1, 2, 6), (1, 3, 4), (1, 3, 5), (1, 3, 6), (1, 4, 5), (1, 4, 6),
 (1, 5, 6), (2, 3, 4), (2, 3, 5), (2, 3, 6), (2, 4, 5), (2, 4, 6),
 (2, 5, 6), (3, 4, 5), (3, 4, 6), (3, 5, 6), (4, 5, 6)]

Cela nous permet de traduire facilement la contrainte qu'exactement 3 cases par ligne et par colonne sont vides. Notons l'utilisation de la construction ... if ... else ... pour pouvoir faire un test utilisé comme une expression (ici, à l'intérieur d'une liste en compréhension) (Dans de nombreux langages, on note cela en utilisant l'opérateur ternaire test ? si_vrai : si_faux alors qu'en Ocaml et plus généralement avec des langages fonctionnels, un if classique convient car il n'y a pas de distinction entre expression et instruction).

    ...
    # 3 cases vides par ligne
    for l in range(7):
        contraintes.append(Or([
            And([
                C[l][c] == 0 if c in positions else C[l][c] != 0
                for c in range(7)
            ])
            for positions in itertools.combinations(range(7), 3)
        ]))
    # 3 cases vides par colonne
    for c in range(7):
        contraintes.append(Or([
            And([
                C[l][c] == 0 if l in positions else C[l][c] != 0
                for l in range(7)
            ])
            for positions in itertools.combinations(range(7), 3)
        ]))
    ...
  1. Reste à prendre en compte les données initiales de la grille. Pour cela, nous allons la traduire sous forme d'une chaîne de caractères :
    ...
    # valeurs initiales de la grille
    grille = [
        ".4.....",
        "..63..6",
        ".....55",
        "...4...",
        "47.....",
        "2..74..",
        ".....1."
    ]
    for l in range(7):
        for c in range(7):
            valeur = grille[l][c]
            if valeur != ".":
                contraintes.append(C[l][c] == int(valeur))
    ...

En résumé, on a la fonction suivante :

def probleme(C):
    s = Solver()
    # encadrement des valeurs des cellules
    for l in range(7):
        for c in range(7):
            s.add(And(C[l][c] >= 0, C[l][c] <= 7))
    # sommes des lignes
    for l in range(7):
        s.add(Sum([C[l][c] for c in range(7)]) == 20)
    # sommes des colonnes
    for c in range(7):
        s.add(Sum([C[l][c] for l in range(7)]) == 20)
    # les carrés 2x2 ne sont pas entièrement pleins
    for l in range(6):
        for c in range(6):
            s.add(Or(
                C[l][c] == 0,
                C[l + 1][c] == 0,
                C[l][c + 1] == 0,
                C[l + 1][c + 1] == 0
            ))
    # 3 cases vides par ligne
    for l in range(7):
        s.add(Or([
            And([
                C[l][c] == 0 if c in positions else C[l][c] != 0
                for c in range(7)
            ])
            for positions in itertools.combinations(range(7), 3)
        ]))
    # 3 cases vides par colonne
    for c in range(7):
        s.add(Or([
            And([
                C[l][c] == 0 if l in positions else C[l][c] != 0
                for l in range(7)
            ])
            for positions in itertools.combinations(range(7), 3)
        ]))
    # valeurs initiales de la grille
    grille = [
        ".4.....",
        "..63..6",
        ".....55",
        "...4...",
        "47.....",
        "2..74..",
        ".....1."
    ]
    for l in range(7):
        for c in range(7):
            valeur = grille[l][c]
            if valeur != ".":
                s.add(C[l][c] == int(valeur))
    # on a fini
    return s
>>> s = probleme(C)
>>> s.check()
sat
>>> s.model()
[c23 = 0,
 c43 = 0,
 ...
 c42 = 2]

Pour rendre cela plus lisible, définissons une petite fonction d'affichage. On utilise as_long() pour convertir en entier les valeurs du modèle.

def affiche_modele(m):
    for l in range(7):
        for c in range(7):
            valeur = m[C[l][c]].as_long()
            if valeur:
                print(valeur, end=" ")
            else:
                print(".", end=" ")
        print("")

On a ainsi :

>>> affiche_modele(s.model())
. 4 . . 4 7 5
. . 6 3 5 . 6
. 3 . . 7 5 5
7 . 5 4 . . 4
4 7 2 . . 7 .
2 . 7 7 4 . .
7 6 . 6 . 1 .

La solution satisfait bien les contraintes que l'on a programmé jusqu'ici, mais le problème n'est pas résolu : l'ensemble des cases remplies n'est pas connexe et les nombres n'apparaissent pas le bon nombre de fois.

Toutes les solutions

Si l'on redemande un modèle, même après un nouveau check(), on obtient le même modèle.

>>> s.check()
sat
>>> s.model()
[c23 = 0,
 c43 = 0,
 ...
 c42 = 2]

La raison est que le solveur va utiliser le même algorithme. Et comme la situation n'a pas changée… le résultat est lui aussi inchangé. Il faut donc ajouter une nouvelle contrainte pour être sûr de ne pas retomber sur le résultat précédent. Mais un nouveau modèle signifie avoir des nouvelles valeurs pour les variables C[l][c]. Plus précisément, il faut qu'au moins une valeur diffère. La contrainte est donc

c0,0m0,0c0,1m0,1  c6,6m6,6&#123;`c_{0,0 &#125; \\neq m_{0, 0} \\vee c_{0, 1} \\neq m_{0, 1} \\vee \\ \\cdots \\ \\vee c_{6, 6} \\neq m_{6, 6}`}

où « &#123;`\\vee`&#125; » désigne la disjonction, ci,j&#123;`c_{i,j&#125;`} la variable correspondant de la case (i,j)&#123;`(i,j)`&#125; et mi,j&#123;`m_{i,j&#125;`} la valeur donnée par le modèle. Cela se traduit par :

>>> s.add(Or([C[l][c] != m[C[l][c]] for l in range(7) for c in range(7)]))

On ne peut donc obtenir à nouveau le modèle précédent. Vérifions cela :

>>> s.check()
sat
>>> affiche_modele(s.model())
7 4 . . 7 . 2
. 5 6 3 . . 6
. 4 . . 6 5 5
7 . 2 4 . 7 .
4 7 6 . 3 . .
2 . . 7 4 7 .
. . 6 6 . 1 7

Il faut donc répéter ce processus jusqu'à ce que les contraintes ne puissent plus être satisfaites pour obtenir toutes les solutions. On en déduit la fonction suivante :

def solutions_temporaires():
    sols = []
    s = probleme(C)
    while s.check() == sat:
        m = s.model()
        sols.append(m)
        # on interdit de retrouver la solution actuelle
        s.add(Or([C[l][c] != m[C[l][c]] for l in range(7) for c in range(7)]))
    return sols

Si l'on a maintenant obtenu toutes les solutions, il y en a encore trop, puisque certaines contraintes restent à prendre en compte :

>>> sols = solutions_temporaires()
>>> len(sols)
1140

Une des contraintes restantes est très simple à mettre en place, celle concernant le nombre d'occurences de chaque valeur.

def teste_occurences(m):
    vals = [m[C[l][c]].as_long() for l in range(7) for c in range(7)]
    for v in range(1, 8):
        if vals.count(v) != v:
            return False
    return True

On en déduit une nouvelle fonction de détermination des solutions :

def solutions_temporaires2():
    sols = []
    s = probleme(C)
    while s.check() == sat:
        m = s.model()
        if teste_occurences(m):
            sols.append(m)
        # on interdit de retrouver la solution actuelle
        s.add(Or([C[l][c] != m[C[l][c]] for l in range(7) for c in range(7)]))
    return sols

Voyons combien il reste de solutions à traîter :

>>> len(solutions_temporaires2())
21

On progresse…

Vérification de la connexité

Pour traduire les contraintes de connexité (les cases numérotées doivent être connexes, en un seul morceau), nous allons utiliser l'algorithme Union-Find.

Présentation de l'algorithme

L'idée est représenter une partition d'ensembles d'entiers à l'aide d'un tableau.

Supposons que l'ensemble des entiers de 0&#123;`0`&#125; à 9&#123;`9`&#125; soit partitionné ainsi :

{0,4},{2},{3,6,8,9},{1,5,7}&#123;`\\{ 0, 4 \\&#125;, \\{ 2 \\}, \\{ 3, 6, 8, 9 \\}, \\{ 1, 5, 7 \\}`}

Une représentation possible de cette partition est le tableau t suivant :

>>> t = [  4,  7, -1,  8, -2, -3, -4,  5,  6,  6]
>>> for i in range(len(t)):
...     print("t[{}] = {:2}".format(i,t[i]), end="\n" if i % 5 == 4 else "   ")
...
t[0] =  4   t[1] =  7   t[2] = -1   t[3] =  8   t[4] = -2
t[5] = -3   t[6] = -4   t[7] =  5   t[8] =  6   t[9] =  6

Remarquons tout d'abord que dans le tableau t, seuls 4 indices ont une valeur strictement négatives :

>>> [i for i in range(len(t)) if t[i] < 0]
[2, 4, 5, 6]

Il y en a exactement un par sous-ensemble de la partition. C'est le représentant du sous-ensemble. Par exemple, {3,6,8,9}&#123;`\\{ 3, 6, 8, 9\\&#125;`} a 6 pour représentant. La valeur associée dans t à chacun de ces représentants est l'opposé du nombre d'éléments du sous-ensemble correspondant. Ainsi, t[5] = -3 dont le sous-ensemble représenté par 5 comporte 3 éléments.

L'interprétation des autres valeurs présentantes dans t est la suivante : en notant y = t[x], si y >= 0 cela signifie que x et y sont dans le même sous-ensemble de la partition, et en répétant le processus avec y, on arrivera au représentant.

>>> i = 3
>>> while True:
...     print(i, end=" ")
...     if t[i] < 0:
...         print("")
...         break
...     print("->", end=" ")
...     i = t[i]
...
3 -> 8 -> 6

On en déduit la fonction find qui permet de trouver le représentant du sous-ensemble contenant un élément donné :

def find(t, x):
    while t[x] >= 0:
        x = t[x]
    return x

Reste à voir comment se construit un tel tableau. Pour cela, l'autre fonction indispensable est union qui permet de fusionner les sous-ensembles dont on a donné deux éléments. Pour cela, on calcule les représentants de chacun des sous-ensembles, et on modifie le tableau en faisant pointer l'un vers l'autre. La taille permet de faire pointer le représentant du plus petit sous-ensemble vers l'autre, afin d'éviter d'avoir des chemins trop longs lors de la recherche des représentants.

Le code de la fonction s'en déduit directement.

def union(t, a, b):
    x = find(t, a)
    y = find(t, b)
    if t[x] < t[y]:
        # y va pointer vers x
        t[x] += t[y]
        t[y] = x
    elif x != y:
        # il faut s'assurer que l'on fusionne
        # deux sous-ensembles distincts
        t[y] += t[x]
        t[x] = y

Notons pour finir qu'au départ, le tableau t ne contient que des valeurs -1, chaque élément constituant un sous-ensemble à lui tout seul.

Application au puzzle

Nous allons maintenant créer une fonction qui vérifie que toutes les cases numérotées sont connexes. Pour cela, on applique l'algorithme Union-Find en fusionnant les cases adjacentes non vides et, à la fin, on vérifie que la composente d'une des cases non vide est de taille 28, le nombre de cases non vides.

def numero_case(l, c):
    return 7 * l + c

def teste_connexite(m):
    vals = [[m[C[l][c]].as_long() for c in range(7)] for l in range(7)]
    t = [-1] * 7 * 7
    # on fusionne les cases remplies adjacentes sur une même ligne
    for l in range(7):
        for c in range(6):
            if (vals[l][c] != 0 and vals[l][c + 1] != 0):
                union(t, numero_case(l, c), numero_case(l, c + 1))
    # on fait de même verticalement
    for l in range(6):
        for c in range(7):
            if (vals[l][c] != 0 and vals[l + 1][c] != 0):
                union(t, numero_case(l, c), numero_case(l + 1, c))
    # on cherche une case non vide,
    # on sait qu'il y en a une sur la première ligne
    for c in range(7):
        if vals[0][c] > 0:
            break
    # on vérifie si sa composante connexe est de taille 28
    return t[find(t, c)] == -28

Modifions notre fonction de recherche de solutions :

def solutions():
    sols = []
    s = probleme(C)
    while s.check() == sat:
        m = s.model()
        if teste_occurences(m) and teste_connexite(m):
            sols.append(m)
        # on interdit de retrouver la solution actuelle
        s.add(Or([C[l][c] != m[C[l][c]] for l in range(7) for c in range(7)]))
    return sols

…et testons :

>>> sols = solutions()
>>> len(sols)
1

Ça se précise.

>>> affiche_modele(sols[0])
7 4 3 . 6 . .
. . 6 3 5 . 6
. . 5 . 5 5 5
. 3 6 4 . . 7
4 7 . . . 7 2
2 . . 7 4 7 .
7 6 . 6 . 1 .

Reste à calculer le score. On rappelle qu'il s'agit du produit des aires des composantes connexes formées de cases vides. On adapte la fonction teste_connexite.

def score(m):
    vals = [[m[C[l][c]].as_long() for c in range(7)] for l in range(7)]
    t = [-1] * 7 * 7
    # on fusionne les cases vides adjacentes sur une même ligne
    for l in range(7):
        for c in range(6):
            if (vals[l][c] == 0 and vals[l][c + 1] == 0):
                union(t, numero_case(l, c), numero_case(l, c + 1))
    # on fait de même verticalement
    for l in range(6):
        for c in range(7):
            if (vals[l][c] == 0 and vals[l + 1][c] == 0):
                union(t, numero_case(l, c), numero_case(l + 1, c))
    # on calcule le score, sachant que les cases remplies forment
    # chacune un sous-ensemble de taille 1, ce qui ne va pas
    # poser de problème pour le score.
    score = 1
    for v in t:
        if v < 0:
            score *= -v
    return score
>>> score(sols[0])
240

On a gagné.