Solving the data leakage problem with learning networks

Julia
MLJ
Genomics
Machine learning
Author

Tom Michoel

Published

August 22, 2025

The data leakage problem

Nowadays every student who has taken a basic course in machine learning is aware of the problem of data leakage and the need for proper separation between training, validation, and test data sets. Nonetheless, survey after survey shows that data leakage is pervasive in ML-based science. How can this discrepancy be explained?

Part of the problem, I believe, is that most ML courses make use of standard, ready-made datasets for putting theory into practice (anyone out there who hasn’t used the MNIST database?). In real life though, and certainly in biology, there is a long series of preprocessing steps to go from raw data to the nice \(X\) and \(y\) array inputs one uses in ML (check out some of the workflows on the Galaxy project to get an idea how complex data preprocessing can be).

These preprocessing workflows are often not run with later ML applications in mind, and often not even by the same person who will do the ML modelling. This leads to plenty of opportunity for data leakage. The pitfalls of leaky preprocessing have been well described (make sure also to memorize the other pitfalls described in the same article!).

What can be done to avoid leaky preprocessing? Learning networks are the answer! Learning networks (not to be confused with deep learning or neural networks) are, in the words of their authors, simple transformations of your existing workflows which can be “exported” to define new, re-usable composite model types (models which typically have other models as hyperparameters) (see also this blog post and the technical paper).

Let’s break this down into two parts: how model composition can eliminate data leakage, and what additional flexibility is offered by learning networks.

A linear pipeline: variable gene selection

We will use gene expression and drug sensitivity data from the first cancer cell line encyclopedia paper. A copy of (the relevant part of) the data is available on JuliaHub and consists of expression data for 18,926 genes and sensitivity data for 24 drugs in 474 cancer cell lines. We load data for one specific drug as our response variable \(y\) and expression data for all genes as our predictor matrix \(X\):

using CSV
using DataFrames
y = DataFrame(CSV.File("CCLE-ActArea.csv")).:"PD-0325901"  # drug sensitivity (ActArea) for PD-0325901
X = DataFrame(CSV.File("CCLE-expr.csv"))
474×18926 DataFrame (display truncated to the first few columns)
 Row │ ZBTB11-AS1  AKT3     MED6     NR2E3    ⋯
     │ Float64     Float64  Float64  Float64  ⋯
   1 │ 6.1161      8.1556   9.7864   3.7977   ⋯
   2 │ 6.1249      4.5676   8.872    3.8455   ⋯
   ⋮ │     ⋮          ⋮        ⋮        ⋮     ⋱
 474 │ 5.2136      4.4004   9.0476   4.0903   ⋯

Since so many genes are profiled in a relatively small number of samples, many will not show much variation:

using Plots
using Statistics
sd = map(x -> std(x), eachcol(X))
histogram(sd; xlabel="Standard deviation", label="")

Therefore it is common to keep only variable genes (for instance, using a cutoff on the standard deviation) for downstream analysis. However, if we simply remove genes with low standard deviation at this point, we enter the leaky preprocessing terrain, because information from all samples has been used to compute the standard deviations.

The common solution is to first split the data into training, validation, and test sets, and then do the filtering on the training data alone. This is messy and error-prone, because one has to manually keep track of the different datasets.
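To make the danger concrete, here is a toy sketch in plain Julia (synthetic data; all names are illustrative, not from the CCLE dataset) comparing the feature set selected from all samples with the one selected from the training rows alone:

```julia
using Random, Statistics

Random.seed!(42)
# 100 toy samples of 5 features whose spreads straddle the cutoff of 1.0
Xtoy = randn(100, 5) .* [0.8 0.9 1.0 1.1 1.2]

train = 1:80                                 # manual train/test split
sd_full  = vec(std(Xtoy, dims=1))            # leaky: computed on all samples
sd_train = vec(std(Xtoy[train, :], dims=1))  # leak-free: training rows only

keep_full  = findall(sd_full  .> 1.0)
keep_train = findall(sd_train .> 1.0)
# Near the cutoff the two selections can disagree: in the leaky version,
# the held-out rows have already influenced which features survive.
```

With real data the two selections will rarely coincide exactly, and every disagreement is a small channel through which test-set information reaches the model.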

A much better solution is to create a linear pipeline composed of two models: a feature selection model and a regression model.

The feature selection model

We want to keep the standard deviation threshold as a parameter of the feature selection model, which does not seem possible with MLJ’s built-in FeatureSelector. Hence we define a new unsupervised model:

using MLJ
import MLJModelInterface
MLJModelInterface.@mlj_model mutable struct FeatureSelectorStd <: Unsupervised
    threshold::Float64 = 1.0::(_ > 0)
end
    
function MLJ.fit(fs::FeatureSelectorStd, verbosity, X)
    # Find the features whose standard deviation exceeds the threshold
    stds = map(x -> std(x), eachcol(X))
    selected_features = findall(stds .> fs.threshold)
    cache = nothing
    report = nothing
    return selected_features, cache, report
end

function MLJ.transform(::FeatureSelectorStd, fitresult, X)
    # Apply the fitted feature selection to (possibly new) data
    return selectcols(X, fitresult)
end

A trivial but nonetheless important point is that the fit function identifies the features with standard deviation greater than the threshold from its argument \(X\), while the transform function applies the feature selection to a possibly different \(X'\), irrespective of the standard deviation values in that \(X'\). To illustrate this, let’s fit a feature selector on all but the first cell line, and use it to transform the left-out one. First we create a machine that binds our feature selector (with the default standard deviation threshold of 1) to the selected samples, then we fit the machine to the data (that is, compute the standard deviation of each feature and find the ones exceeding the threshold), and finally transform (i.e., select the relevant features from) the held-out sample:

mfs = machine(FeatureSelectorStd(),selectrows(X,2:nrow(X)));
fit!(mfs)
MLJ.transform(mfs,selectrows(X,1))
[ Info: Training machine(FeatureSelectorStd(threshold = 1.0), …).
1×4029 DataFrame (display truncated to the first few columns)
 Row │ AKT3     GNPDA1   KCNE3    CDH2     ⋯
     │ Float64  Float64  Float64  Float64  ⋯
   1 │ 8.1556   10.329   3.8674   10.089   ⋯

Creating a feature selection model with a different standard deviation threshold is as simple as

FeatureSelectorStd(threshold=2.0)
FeatureSelectorStd(
  threshold = 2.0)

The regression model

To predict drug sensitivity from gene expression data, we will use random forest regression. The model is imported as follows:

using MLJDecisionTreeInterface
RandomForestRegressor = @load RandomForestRegressor pkg="DecisionTree";
[ Info: For silent loading, specify `verbosity=0`. 
import MLJDecisionTreeInterface ✔

The pipeline model

A linear pipeline model that first performs feature selection and then random forest regression with default parameter values is easily constructed:

fs_rf_model = FeatureSelectorStd() |> RandomForestRegressor()
DeterministicPipeline(
  feature_selector_std = FeatureSelectorStd(
        threshold = 1.0), 
  random_forest_regressor = RandomForestRegressor(
        max_depth = -1, 
        min_samples_leaf = 1, 
        min_samples_split = 2, 
        min_purity_increase = 0.0, 
        n_subfeatures = -1, 
        n_trees = 100, 
        sampling_fraction = 0.7, 
        feature_importance = :impurity, 
        rng = Random.TaskLocalRNG()), 
  cache = true)

A pipeline model can be fitted and evaluated like any other MLJ model; see the common MLJ workflows.

fs_rf_mach = machine(fs_rf_model,X,y)
train, test = partition(1:nrow(X), 0.8, shuffle=true)
fit!(fs_rf_mach, rows=train)
yhat = predict(fs_rf_mach, X[test,:])
scatter(y[test], yhat; label="", xlabel="True response", ylabel="Predicted response")
[ Info: Training machine(DeterministicPipeline(feature_selector_std = FeatureSelectorStd(threshold = 1.0), …), …).
[ Info: Training machine(:feature_selector_std, …).
[ Info: Training machine(:random_forest_regressor, …).
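The same leak-free behaviour carries over to cross-validation. Here is a self-contained sketch with synthetic data (the model names come from MLJ, but the data and the small tree count are made up for speed): `evaluate` treats the pipeline as a single model and refits every stage, preprocessing included, on each fold's training rows only:

```julia
using MLJ, DataFrames, Random

Random.seed!(0)
RandomForestRegressor = @load RandomForestRegressor pkg="DecisionTree" verbosity=0
Standardizer = @load Standardizer pkg=MLJModels verbosity=0
pipe = Standardizer() |> RandomForestRegressor(n_trees=20)

# Synthetic stand-ins for the expression matrix and the drug response
Xs = DataFrame(randn(200, 5), :auto)
ys = Xs.x1 .+ 0.1 .* randn(200)

# `evaluate` refits the whole pipeline (standardizer included) on each
# fold's training rows, so the preprocessing cannot see the held-out rows
e = evaluate(pipe, Xs, ys; resampling=CV(nfolds=3), measure=rms, verbosity=0)
```

Running the same evaluation on the real `X` and `y` above works identically, with the feature selector in place of (or in addition to) the standardizer.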

By incorporating data preprocessing steps in the ML model, we avoid having to remember to split the data before any preprocessing, giving a lot more flexibility in the analysis and avoiding data leakage. The real beauty though is perhaps that the preprocessing parameters (here, standard deviation threshold) become learnable hyperparameters of the combined model! In other words, instead of setting preprocessing parameters to some arbitrary value without clear justification, as is common in bioinformatics pipelines, we can now tune their value to the learning goal (finding the best model to predict drug sensitivity from gene expression).

As an example, we tune the standard deviation threshold using grid search and 5-fold cross-validation on the training data, while keeping the default parameters for the random forest regressor (following this tutorial):

r = range(fs_rf_model, :(feature_selector_std.threshold), lower=0.5, upper=3.0)
tuned_fs_rf_model = TunedModel(
    model=fs_rf_model,
    resampling=CV(nfolds=5),
    tuning=Grid(resolution=20),
    range=r,
    measure=rms
)
tuned_fs_rf_mach = machine(tuned_fs_rf_model, X, y)
fit!(tuned_fs_rf_mach, rows=train)
fitted_params(tuned_fs_rf_mach)[:best_model]
[ Info: Training machine(DeterministicTunedModel(model = DeterministicPipeline(feature_selector_std = FeatureSelectorStd(threshold = 1.0), …), …), …).
[ Info: Attempting to evaluate 20 models.

Evaluating over 20 metamodels: 100%[=========================] Time: 0:02:23
DeterministicPipeline(
  feature_selector_std = FeatureSelectorStd(
        threshold = 1.8157894736842106), 
  random_forest_regressor = RandomForestRegressor(
        max_depth = -1, 
        min_samples_leaf = 1, 
        min_samples_split = 2, 
        min_purity_increase = 0.0, 
        n_subfeatures = -1, 
        n_trees = 100, 
        sampling_fraction = 0.7, 
        feature_importance = :impurity, 
        rng = Random.TaskLocalRNG()), 
  cache = true)

We find that the best standard deviation threshold is about 1.82. The average RMS at each threshold can be plotted:

plot(tuned_fs_rf_mach)

Linear pipelines can of course contain multiple preprocessing steps. For instance, here is a pipeline that first performs variable gene selection, then standardizes each gene to mean 0 and standard deviation 1, and finally fits a regression model1:

Standardizer = @load Standardizer pkg=MLJModels
FeatureSelectorStd() |> Standardizer() |> RandomForestRegressor()
[ Info: For silent loading, specify `verbosity=0`. 
import MLJModels ✔
DeterministicPipeline(
  feature_selector_std = FeatureSelectorStd(
        threshold = 1.0), 
  standardizer = Standardizer(
        features = Symbol[], 
        ignore = false, 
        ordered_factor = false, 
        count = false), 
  random_forest_regressor = RandomForestRegressor(
        max_depth = -1, 
        min_samples_leaf = 1, 
        min_samples_split = 2, 
        min_purity_increase = 0.0, 
        n_subfeatures = -1, 
        n_trees = 100, 
        sampling_fraction = 0.7, 
        feature_importance = :impurity, 
        rng = Random.TaskLocalRNG()), 
  cache = true)

A learning network: supervised gene selection

In the previous example we selected genes based on their variability. But since we’re ultimately interested in building a predictive model for the response variable (drug sensitivity), would it not make more sense to select genes based on their correlation with the response? In lasso regression, features are guaranteed to be absent in the optimal model if their correlation with the response is sufficiently small (so-called safe feature elimination). In genomics, when constructing polygenic scores, it is common to select features (in this case, SNPs) based on their effect size (that is, correlation with the outcome) in the corresponding GWAS. According to Barnett et al (2022), using the entire dataset for GWAS and this kind of feature selection is the most common form of data leakage in this area.
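The screening rule itself is simple: keep the features whose absolute correlation with the response exceeds a threshold. Here is a minimal plain-Julia sketch of that idea (the helper name `select_by_cor` is hypothetical, not MLJ API), with a comment flagging exactly where the leakage creeps in:

```julia
using Statistics

# Correlation screening: keep columns of X whose |cor| with y exceeds
# a threshold. (Hypothetical helper illustrating the idea in the text.)
select_by_cor(X, y; threshold=0.1) =
    findall(j -> abs(cor(view(X, :, j), y)) > threshold, axes(X, 2))

# Tiny example with known correlations:
y = [1.0, 2.0, 3.0, 4.0]
X = [y [1.0, -1.0, 1.0, -1.0]]   # col 1: cor = 1, col 2: cor ≈ -0.45

select_by_cor(X, y; threshold=0.5)   # selects only column 1

# The leakage trap: these correlations must be computed on the training
# rows only, e.g. select_by_cor(X[train, :], y[train]); screening on the
# full dataset lets information about the test targets leak into the
# feature selection step.
```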

A feature selection step that depends on the target variable \(y\) cannot be put in a linear pipeline, because at most one of the pipeline components may be a supervised model2. What we need instead is a workflow like this:

Learning network diagram

Learning networks are exactly what is needed to implement such a model.

The supervised feature selection model

First we construct our feature selector:

MLJModelInterface.@mlj_model mutable struct FeatureSelectorCor <: Unsupervised
    threshold::Float64 = 0.1::(_ > 0)
end

function MLJ.fit(fs::FeatureSelectorCor, verbosity, X, y)
    # Find the correlation of each feature with y
    cors = map(x -> abs(cor(x,y)), eachcol(X))
    selected_features = findall(cors .> fs.threshold)
    cache = nothing
    report = nothing
    return selected_features, cache, report
end

function MLJ.predict(::FeatureSelectorCor, fitresult, X)
    # Return the selected features
    return selectcols(X,fitresult)
end

function MLJ.transform(::FeatureSelectorCor, fitresult, X)
    # Return the selected features
    return selectcols(X,fitresult)
end

The learning network

Now we construct the learning network, pretty much following the documentation.

First the data source nodes:

Xs = source(X);
ys = source(y);

Now the feature selector machine and the transformed feature data node:

mach_fs = machine(FeatureSelectorCor(threshold=0.2),Xs, ys);
Xt = MLJ.transform(mach_fs,Xs)
Node @636 → FeatureSelectorCor(…)
  args:
    1:  Source @163
  formula:
    transform(
      machine(FeatureSelectorCor(threshold = 0.2), …), 
      Source @163)

Then the random forest regression machine and the predicted response node:

mach_rf = machine(RandomForestRegressor(), Xt, ys)
yh = predict(mach_rf,Xt)
Node @587 → RandomForestRegressor(…)
  args:
    1:  Node @636 → FeatureSelectorCor(…)
  formula:
    predict(
      machine(RandomForestRegressor(max_depth = -1, …), …), 
      transform(
        machine(FeatureSelectorCor(threshold = 0.2), …), 
        Source @163))

To fit the learning network on the training samples, we call fit! on the output node, which triggers training of all the machines on which it depends:

fit!(yh, rows=train)
[ Info: Training machine(FeatureSelectorCor(threshold = 0.2), …).
[ Info: Training machine(RandomForestRegressor(max_depth = -1, …), …).
Node @587 → RandomForestRegressor(…)
  args:
    1:  Node @636 → FeatureSelectorCor(…)
  formula:
    predict(
      machine(RandomForestRegressor(max_depth = -1, …), …), 
      transform(
        machine(FeatureSelectorCor(threshold = 0.2), …), 
        Source @163))

The fitted model can be compared to the true values on the test samples using the syntax yh(rows=test); the RMSE in this case is 0.9683378174604053.

scatter(y[test],yh(rows=test), label="", xlabel="True response", ylabel="Predicted response")
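For reference, the RMSE quoted above is just the root-mean-square error, which in plain Julia (with Statistics from the standard library) is a one-liner:

```julia
using Statistics

# Root-mean-square error between true and predicted responses.
rmse(ytrue, ypred) = sqrt(mean((ytrue .- ypred) .^ 2))

rmse([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])  # perfect prediction: 0.0
rmse([0.0, 0.0], [3.0, 4.0])            # sqrt((9 + 16) / 2) ≈ 3.54
```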

Exporting the learning network as a new model type

A learning network must be exported as a new model type to allow hyperparameter tuning. This has the additional benefit that the feature selector and regression model can easily be swapped for alternative models. We follow the documentation and first define a new composite model type with the same supertype as a random forest regressor:

supertype(typeof(RandomForestRegressor()))
Deterministic

Our new composite model consists of a feature selector and a regressor:

mutable struct CompositeFS <: DeterministicNetworkComposite
    featureselector
    regressor
end

Now we need to make our learning network generic and wrap it in a prefit method:

import MLJBase
function MLJBase.prefit(composite::CompositeFS, verbosity, X, y)

    # the data source nodes
    Xs = source(X)
    ys = source(y)

    # the supervised feature selector
    mach_fs = machine(:featureselector, Xs, ys);
    Xt = MLJ.transform(mach_fs,Xs)

    # the regressor on the preprocessed data
    mach_regr = machine(:regressor, Xt, ys)
    yhat = predict(mach_regr, Xt)

    verbosity > 0 && @info "I'm a learning network"

    # return "learning network interface":
    return (; predict=yhat)
end

An instance of the new model type is created and fitted using the standard interface:

composite_fs_model = CompositeFS(FeatureSelectorCor(), RandomForestRegressor())
mach_composite_fs = machine(composite_fs_model, X, y)
fit!(mach_composite_fs, rows=train)
yhat = predict(mach_composite_fs, X[test,:])
scatter(y[test], yhat; label="", xlabel="True response", ylabel="Predicted response")
[ Info: Training machine(CompositeFS(featureselector = FeatureSelectorCor(threshold = 0.1), …), …).
[ Info: I'm a learning network
[ Info: Training machine(:featureselector, …).
[ Info: Training machine(:regressor, …).

Tuning the preprocessing hyperparameter

Using the new model type, we can tune the correlation threshold, keeping the default random forest hyperparameters:

r = range(composite_fs_model, :(featureselector.threshold), lower=0.05, upper=0.35)
tuned_composite_fs_model = TunedModel(
    model=composite_fs_model,
    resampling=CV(nfolds=5),
    tuning=Grid(resolution=20),
    range=r,
    measure=rms
)
tuned_composite_fs_mach = machine(tuned_composite_fs_model, X, y)
fit!(tuned_composite_fs_mach, rows=train)
fitted_params(tuned_composite_fs_mach)[:best_model]
[ Info: Training machine(DeterministicTunedModel(model = CompositeFS(featureselector = FeatureSelectorCor(threshold = 0.1), …), …), …).
[ Info: Attempting to evaluate 20 models.

Evaluating over 20 metamodels: 100%[=========================] Time: 0:03:46
CompositeFS(
  featureselector = FeatureSelectorCor(
        threshold = 0.2710526315789474), 
  regressor = RandomForestRegressor(
        max_depth = -1, 
        min_samples_leaf = 1, 
        min_samples_split = 2, 
        min_purity_increase = 0.0, 
        n_subfeatures = -1, 
        n_trees = 100, 
        sampling_fraction = 0.7, 
        feature_importance = :impurity, 
        rng = Random.TaskLocalRNG()))
plot(tuned_composite_fs_mach)
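As a sanity check on the tuning output: assuming Grid(resolution=20) lays its points out uniformly over the numeric range, the candidate thresholds are 20 evenly spaced values between 0.05 and 0.35, and the winning threshold is the 15th of them:

```julia
# 20 evenly spaced candidate thresholds over the tuning range
# (assumes MLJ's Grid uses uniform spacing for a numeric range):
grid = collect(range(0.05, 0.35; length=20))
grid[15]   # 0.2710526315789474, the reported best threshold
```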

Footnotes

  1. This is of course a hypothetical example: a random forest regressor does not need standardization or any other type of scaling, because it only cares about the order of the values of a feature, not their absolute values.↩︎

  2. The pipeline documentation states that “Some transformers that have type Unsupervised (so that the output of transform is propagated in pipelines) may require a target variable for training […]. Provided they appear before any Supervised component in the pipelines, such models are supported.” In theory this should apply to a FeatureSelectorCor, but including one in a linear pipeline gives an error during fitting. Maybe I’ve not understood how to do this properly?↩︎