Learning draughts evaluation functions using Keras/TensorFlow

Rein Halbersma
Posts: 1722
Joined: Wed Apr 14, 2004 16:04

Post by Rein Halbersma » Mon Nov 30, 2020 15:33

Introduction
Over the last week or so, I've written a Python script that creates a draughts evaluation function as a Keras/TensorFlow neural network that can be trained with gradient descent. The script can learn a draughts evaluation function using both positional features and the Scan-like pattern indices that are currently in use by all the top draughts engines, including Kingsrow.

Acknowledgement
Ed Gilbert has been extremely helpful explaining his Kingsrow evaluation function, thanks Ed! As a disclaimer, here is the information exchanged between us:
  • I have only seen Ed's low-level C++ eval function that combines the features + pattern indices and their weights into a score, his raw training data, his trained weights and the corresponding score predictions.
  • I have not seen Ed's top-level C++ eval function that translates a position into features and pattern indices, nor his actual gradient descent code, nor his training data generation code.

Model Implementation
The Kingsrow eval can be parameterized by the following constants:

Code: Select all

num_features = 5    # balance, tempo, men, multi kings and first king
num_patterns = 4    # partially overlapping areas
num_phases   = 2    # opening and endgame
num_views    = 2    # normal and mirrored
num_pieces   = 3    # empty, black and white men
num_squares  = 12   # 4x6 areas on a checkerboard
index_shape  = (num_pieces**num_squares, num_patterns, num_phases)
Given these or similar constants, the core of the script is a handful of lines of Python that call the excellent Keras library on top of TensorFlow. Below is the code that reimplements the current Kingsrow evaluation function in Keras:

Code: Select all

features = keras.Input(shape=(num_features,), name='features')
patterns = keras.Input(shape=(num_patterns, num_views), name='patterns', dtype='int32')
phases   = keras.Input(shape=(num_phases,), name='phases')

feature_scores       = keras.layers.Dense(units=num_phases, use_bias=False, name='feature-scores')(features)
pattern_scores       = SparseIndexLookup(units=index_shape, name='pattern-scores')(patterns)
scores               = keras.layers.Lambda(score_reduce, name='scores')([feature_scores, pattern_scores])
phase_weighted_score = keras.layers.Lambda(phase_reduce, name='phase-weighted-score')([scores, phases])
value_head           = keras.layers.Activation(activation='sigmoid', name='value-head')(phase_weighted_score)

model = keras.Model(inputs=[features, patterns, phases], outputs=value_head)
Almost the entire eval can be written using already implemented Keras functionality such as Input, Dense, Lambda, Activation and Model. The tricky part was getting the sparse index lookup and the scores per phase right. This was solved using the TensorFlow functions tf.reduce_sum and tf.gather. See the implementations of my SparseIndexLookup layer class and the score_reduce() and phase_reduce() functions.
Code: Select all

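import tensorflow as tf
from tensorflow import keras


def score_reduce(x):
    # x is the list [feature_scores, pattern_scores], each of shape (None, num_phases).
    # Reducing over axis 0 adds the two tensors element-wise.
    return tf.reduce_sum(x, axis=0)


def phase_reduce(x):
    # x is the list [scores, phases], each of shape (None, num_phases).
    # The default keepdims=False would give output shape (None,).
    # Instead, keepdims=True will give output shape (None, 1).
    # This enables BatchNormalization applying axis=-1 on its input
    # when following a Lambda layer calling this function.
    return tf.reduce_sum(x[0] * x[1], axis=1, keepdims=True)


class SparseIndexLookup(keras.layers.Layer):
    def __init__(self, units=None, kernel_regularizer=None, **kwargs):
        super().__init__(**kwargs)
        # units is the weight shape: (num_indices, num_patterns, num_phases)
        self.units = units
        # Accepts None, a string, a config dict or a Regularizer instance
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)
        self.W = self.add_weight(
            name='weights',
            shape=self.units,
            initializer='zeros',
            regularizer=self.kernel_regularizer,
            trainable=True
        )

    def call(self, inputs):
        # inputs holds pattern indices with shape (None, num_patterns, num_views).
        # For each view, gather one weight row per pattern and sum over patterns.
        scores = [
            tf.reduce_sum([
                tf.gather(self.W[:, pattern, :], inputs[:, pattern, view], axis=0)
                for pattern in range(self.units[1])
            ], axis=0)
            for view in range(2)  # 0 = normal, 1 = mirrored
        ]
        # Normal view minus mirrored view, shape (None, num_phases)
        return scores[0] - scores[1]

    def get_config(self):
        config = super().get_config()
        config.update({
            'units': self.units,
            # Serialize so that the config stays JSON-friendly
            'kernel_regularizer': keras.regularizers.serialize(self.kernel_regularizer)
        })
        return config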
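To make the data flow concrete, here is a minimal NumPy sketch of the same forward pass (Dense without bias, sparse lookup, score_reduce, phase_reduce, sigmoid). It is only an illustration, not the script itself: the names W_features, W_patterns and evaluate are made up for this example, the weights are random, and the index range is shrunk to 3**4 to keep the arrays small.

```python
import numpy as np

# Toy sizes: the model in the post uses 3**12 indices per pattern; 3**4 keeps this small.
num_features, num_patterns, num_phases, num_views = 5, 4, 2, 2
num_indices = 3 ** 4

rng = np.random.default_rng(0)
# Hypothetical random weights standing in for trained ones.
W_features = rng.normal(size=(num_features, num_phases))
W_patterns = rng.normal(size=(num_indices, num_patterns, num_phases))


def evaluate(features, patterns, phases):
    """features: (batch, 5) floats, patterns: (batch, 4, 2) ints, phases: (batch, 2) floats."""
    # Dense layer without bias: one score per phase.
    feature_scores = features @ W_features
    # Pair every index with its own pattern number, then gather the weight rows.
    pattern_ids = np.arange(num_patterns)[None, :, None]
    looked_up = W_patterns[patterns, pattern_ids, :]      # (batch, patterns, views, phases)
    per_view = looked_up.sum(axis=1)                      # sum over patterns
    pattern_scores = per_view[:, 0, :] - per_view[:, 1, :]  # normal minus mirrored view
    # score_reduce: add feature and pattern scores per phase.
    scores = feature_scores + pattern_scores
    # phase_reduce: weight the per-phase scores and keep a trailing axis of size 1.
    weighted = (scores * phases).sum(axis=1, keepdims=True)
    # Sigmoid value head.
    return 1.0 / (1.0 + np.exp(-weighted))


# Three made-up positions: pure opening, mixed phase, pure endgame.
features = rng.normal(size=(3, num_features))
patterns = rng.integers(0, num_indices, size=(3, num_patterns, num_views))
phases = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
values = evaluate(features, patterns, phases)
print(values.shape)  # (3, 1)
```

Each position touches only num_patterns * num_views rows of W_patterns, which is why a gather is so much cheaper than a dense matrix product over a one-hot input.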
Calling model.summary() on the evaluation function gives the following table:
Code: Select all

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
features (InputLayer)           [(None, 5)]          0
__________________________________________________________________________________________________
patterns (InputLayer)           [(None, 4, 2)]       0
__________________________________________________________________________________________________
feature-scores (Dense)          (None, 2)            10          features[0][0]
__________________________________________________________________________________________________
pattern-scores (SparseIndexLook (None, 2)            4251528     patterns[0][0]
__________________________________________________________________________________________________
scores (Lambda)                 (None, 2)            0           feature-scores[0][0]
                                                                 pattern-scores[0][0]
__________________________________________________________________________________________________
phases (InputLayer)             [(None, 2)]          0
__________________________________________________________________________________________________
phase-weighted-score (Lambda)   (None, 1)            0           scores[0][0]
                                                                 phases[0][0]
__________________________________________________________________________________________________
value-head (Activation)         (None, 1)            0           phase-weighted-score[0][0]
==================================================================================================
Total params: 4,251,538
Trainable params: 4,251,538
Non-trainable params: 0
__________________________________________________________________________________________________
In total there are 5 * 2 + 3**12 * 4 * 2 ≈ 4.25 million weights. Because a picture says more than a few million weights, Keras also provides a plot_model() function that shows how the data flows along the neural network graph structure:

[Image: plot_model() diagram of the evaluation network]

Note that this is not a "deep learning" model such as used in AlphaZero, but rather a very shallow neural network. Furthermore, it consumes very sparse inputs: of the 3**12 * 4 * 2 pattern weights, each position selects only 4 * 2 = 8 indices (one per pattern and view), and each index contributes one weight per phase. So only 16 pattern weights (and 10 feature weights) contribute to the eval for each position.
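The arithmetic above can be checked in a few lines of Python. The sketch also shows one way a 12-square area could be turned into an index in the range 0 .. 3**12 - 1, namely as a base-3 number. That encoding, the digit assignment and the pattern_index name are assumptions for illustration only; the actual Kingsrow index computation is not shown in this thread.

```python
num_features, num_patterns, num_phases, num_views = 5, 4, 2, 2
num_pieces, num_squares = 3, 12

# Weight counts as reported by model.summary().
feature_weights = num_features * num_phases
pattern_weights = num_pieces ** num_squares * num_patterns * num_phases
total_weights = feature_weights + pattern_weights

# Per position: one index per (pattern, view), each selecting one weight per phase.
active_pattern_weights = num_patterns * num_views * num_phases

# Assumed digit assignment, for illustration only.
EMPTY, BLACK, WHITE = 0, 1, 2


def pattern_index(squares):
    """Base-3 encode the contents of the squares in one area (hypothetical layout)."""
    assert len(squares) == num_squares
    index = 0
    for piece in squares:
        index = index * num_pieces + piece
    return index


print(total_weights)                          # 4251538
print(active_pattern_weights)                 # 16
print(pattern_index([EMPTY] * num_squares))   # 0
```

With 26 active weights out of roughly 4.25 million, fewer than one weight in 160,000 is touched per position.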

Manually initializing the Keras network with the Kingsrow weights reproduces the exact same predicted scores on a training set of 230 million positions. This is a good test that the Keras model in Python is a faithful reimplementation of the Kingsrow eval as Ed had written it in C++.

Model Training
Using Keras to train the weights from scratch on these 230 million positions is another story. I used an out-of-the-box gradient descent optimizer ("Adam") without any further tricks, using a mini-batch size of 2**16 (about 65K) positions. Letting Keras train for 30 epochs (= 30 full passes over the data) took about 30 minutes on my machine, which has a $140 GPU with 4 GB of RAM (GTX 1050 Ti).

Code: Select all

model.compile(optimizer='adam', loss='mse')
model.fit(
    X_train, y_train, 
    batch_size=2**16,
    epochs=30,
    validation_data=(X_val, y_val)
)
The resulting mean squared error was 0.0367, very similar to the mean squared error in Ed's own optimization program. Ed is currently running a 12,000-game engine match between the latest Kingsrow and a Kingsrow version using these Keras-tuned weights. The very preliminary result after 164 games is 1 win and otherwise all draws for the Keras-trained weights versus the Kingsrow native weights. Finishing the match will take another day or so.

There are still plenty of things left to explore, such as weight regularization, learning rate schedules, different loss targets and activation functions, batch normalization, etc. Also, this script takes the training data as given, but you could easily use it inside a Reinforcement Learning training loop that alternates between data generation (so-called Policy Evaluation) and re-training the weights (so-called Policy Improvement). This is only the end of the beginning ;-)

Conclusion
In the coming days, I'll be cleaning up the support code (there's also code to read and write weights from disk, and to read in training data, not shown here) and releasing everything on GitHub. Wrapping the model in a proper keras.Model class is also on the to-do list (this interferes with the summary() and plot_model() functions, so I need to figure that out).

Regardless of the exact match outcome, I think it's fair to say that it's now possible to get out-of-the-box world class performance using Keras/TensorFlow, given of course that you have a good training game generation pipeline set up already. This should make good on my claim from last May that it is in principle possible to use a professional optimization library for learning draughts evaluation functions.

BertTuyt
Posts: 1592
Joined: Wed Sep 01, 2004 19:42

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by BertTuyt » Mon Nov 30, 2020 16:01

Rein, thanks for sharing, I will certainly also redo the optimization of my evaluation based upon your framework.

If NNUE also yields similar results, then we can also do away with manually selecting features and pattern geometries.
Then, in your words, we are really at the new beginning of the end, and we will boldly go where no man has gone before....

Bert

Krzysztof Grzelak
Posts: 1368
Joined: Thu Jun 20, 2013 17:16
Real name: Krzysztof Grzelak

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by Krzysztof Grzelak » Mon Nov 30, 2020 16:02

This is where the problem starts - and the problem is the graphics card.

BertTuyt
Posts: 1592
Joined: Wed Sep 01, 2004 19:42

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by BertTuyt » Mon Nov 30, 2020 16:05

At least you can run TensorFlow on a CPU (and use all cores); not sure how long it will take.
But if the time is below 1 day then I would not consider this a major problem.

Bert

Rein Halbersma
Posts: 1722
Joined: Wed Apr 14, 2004 16:04

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by Rein Halbersma » Mon Nov 30, 2020 16:19

Krzysztof Grzelak wrote:
Mon Nov 30, 2020 16:02
This is where the problem starts - and the problem is the graphics card.
https://www.ebay.com/b/NVIDIA-GeForce-G ... _110675693

EDIT: note that you don't need a graphics card to run Kingsrow or Scan with such Keras-tuned weights. You can simply import the weights directly into the C++ eval that is executed on the CPU. The GPU is only used during training. This is different for deep learning networks such as AlphaZero and LeelaZero, where the engine also benefits from a big expensive graphics card.
Last edited by Rein Halbersma on Mon Nov 30, 2020 16:31, edited 1 time in total.

Rein Halbersma
Posts: 1722
Joined: Wed Apr 14, 2004 16:04

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by Rein Halbersma » Mon Nov 30, 2020 16:27

BertTuyt wrote:
Mon Nov 30, 2020 16:01
Rein, thanks for sharing, I will certainly also redo the optimization of my evaluation based upon your framework.

If NNUE also yields similar results, then we can also do away with manually selecting features and pattern geometries.
Then, in your words, we are really at the new beginning of the end, and we will boldly go where no man has gone before....

Bert
I don't know. The material features still contribute Elo to an eval based purely on Scan-like patterns.

As Fabien pointed out (either here or in an email to me, can't remember), you need to start from first principles: what is essential input from the game's perspective that you want your eval to capture? It was his original insight that localized patterns were fruitful primitives for draughts. Once you have the primitives, you can start optimizing.

Maybe raw inputs are good enough as well (it was for checkers). Or maybe you need Piece * Neighbor interaction as I described in the other thread, combined using LocallyConnected convolution layers.

Rein Halbersma
Posts: 1722
Joined: Wed Apr 14, 2004 16:04

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by Rein Halbersma » Mon Nov 30, 2020 16:35

BertTuyt wrote:
Mon Nov 30, 2020 16:05
At least you can run TensorFlow on a CPU (and use all cores); not sure how long it will take.
But if the time is below 1 day then I would not consider this a major problem.

Bert
I'm running the same optimization on the CPU now: it takes 4.5 minutes per epoch instead of 1 minute, so the total training time for 30 epochs should be 2h15m instead of 30min. That's on 231 million positions, using 10 out of 12 hyperthreads on my 6-core CPU.

Krzysztof Grzelak
Posts: 1368
Joined: Thu Jun 20, 2013 17:16
Real name: Krzysztof Grzelak

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by Krzysztof Grzelak » Mon Nov 30, 2020 17:14

Rein Halbersma wrote:
Mon Nov 30, 2020 16:19
https://www.ebay.com/b/NVIDIA-GeForce-G ... _110675693

EDIT: note that you don't need a graphics card to run Kingsrow or Scan with such Keras-tuned weights. You can simply import the weights directly into the C++ eval that is executed on the CPU. The GPU is only used during training. This is different for deep learning networks such as AlphaZero and LeelaZero, where the engine also benefits from a big expensive graphics card.
And which gives the better result during a game: the (expensive) graphics card or the CPU?

Rein Halbersma
Posts: 1722
Joined: Wed Apr 14, 2004 16:04

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by Rein Halbersma » Mon Nov 30, 2020 17:17

Krzysztof Grzelak wrote:
Mon Nov 30, 2020 17:14
Rein Halbersma wrote:
Mon Nov 30, 2020 16:19
https://www.ebay.com/b/NVIDIA-GeForce-G ... _110675693

EDIT: note that you don't need a graphics card to run Kingsrow or Scan with such Keras-tuned weights. You can simply import the weights directly into the C++ eval that is executed on the CPU. The GPU is only used during training. This is different for deep learning networks such as AlphaZero and LeelaZero, where the engine also benefits from a big expensive graphics card.
And which gives the better result during a game: the (expensive) graphics card or the CPU?
Forget about the GPU. For Scan, Kingsrow and other programs currently using patterns, nothing changes, not one line of code: they can simply choose to import other weights, and that's it.
Last edited by Rein Halbersma on Mon Nov 30, 2020 17:30, edited 1 time in total.

Krzysztof Grzelak
Posts: 1368
Joined: Thu Jun 20, 2013 17:16
Real name: Krzysztof Grzelak

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by Krzysztof Grzelak » Mon Nov 30, 2020 17:23

Thanks for the answer Rein.

Rein Halbersma
Posts: 1722
Joined: Wed Apr 14, 2004 16:04

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by Rein Halbersma » Tue Dec 01, 2020 14:59

Ed communicated the engine match results: total 0.502 score, 12640 games, 79 wins, 35 losses, 12526 draws, 0 unknown. That's an Elo difference of 1.2 in favor of the Keras-tuned weights. That's a very small but statistically significant difference (4.12 standard deviations, Likelihood-of-superiority is 99.99811%).
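For readers who want to check these numbers, here is a small sketch using the usual textbook formulas: the logistic Elo model for the rating difference, and a normal approximation on wins minus losses (draws drop out) for the significance. These formulas are my assumption; the tool that produced the match report may compute things slightly differently.

```python
import math

wins, losses, draws = 79, 35, 12526
games = wins + losses + draws

# Match score: a win counts 1, a draw counts 1/2.
score = (wins + 0.5 * draws) / games

# Logistic Elo model: expected score = 1 / (1 + 10 ** (-elo / 400)).
elo = 400.0 * math.log10(score / (1.0 - score))

# Normal approximation: draws carry no information about which side is better.
z = (wins - losses) / math.sqrt(wins + losses)

# Likelihood of superiority: standard normal CDF evaluated at z.
los = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

print(f"games={games} score={score:.3f} elo={elo:+.1f} z={z:.2f} LOS={100 * los:.5f}%")
```

Under these assumptions the sketch gives the same score, Elo difference and number of standard deviations as reported above.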

Madeleine Birchfield
Posts: 12
Joined: Mon Jun 22, 2020 12:36
Real name: Madeleine Birchfield

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by Madeleine Birchfield » Wed Dec 02, 2020 21:53

Why TensorFlow and not PyTorch?

Rein Halbersma
Posts: 1722
Joined: Wed Apr 14, 2004 16:04

Re: Learning draughts evaluation functions using Keras/TensorFlow

Post by Rein Halbersma » Wed Dec 02, 2020 23:23

Madeleine Birchfield wrote:
Wed Dec 02, 2020 21:53
Why TensorFlow and not PyTorch?
Because Keras has so little boilerplate. But if you prefer, I can write a PyTorch version.
