Imputation/SRMI

What Is It

Suppose you have some variable (\(y\)) that has missing observations (some people didn't respond to the survey, data was left blank in the administrative record, etc.). You also have some set of observable characteristics with no missing information (\(X\)). The basic idea of imputation is to estimate the distribution of \(y\) conditional on \(X\), \(f(y|X)\), and then use that estimate to draw plausible values of \(y\) for each missing observation so that you can estimate some statistic \(Q(y)\) (the mean, the median, a regression coefficient that conditions on \(y\), ...). The imputation model must be "congenial" to (or "proper" for) the statistic you want to estimate, which means the imputes should be drawn such that \(Q(\hat{y})\) is unbiased and its estimates have valid confidence intervals (i.e., uncertainty is accounted for correctly; see van Buuren, 2018 for a much more accurate and thorough discussion).
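As a minimal sketch of that idea (not this package's implementation), the code below assumes \(f(y|X)\) is a linear regression with normal errors. It fits the regression on the observed rows, draws a plausible value for each missing row as the prediction plus a random error, and then computes \(Q(y)\) as the mean. The function name and the simulated data are hypothetical.

```python
import numpy as np


def stochastic_regression_impute(y, X, rng):
    """Fill NaNs in y with draws from an estimated f(y|X).

    Assumes y = X @ beta + e with e ~ Normal(0, sigma^2).
    """
    observed = ~np.isnan(y)
    X1 = np.column_stack([np.ones(len(y)), X])  # add an intercept
    beta, *_ = np.linalg.lstsq(X1[observed], y[observed], rcond=None)
    resid = y[observed] - X1[observed] @ beta
    sigma = resid.std(ddof=X1.shape[1])
    y_imp = y.copy()
    n_missing = int((~observed).sum())
    # Predicted mean plus a random error draw, so imputes keep the spread of y
    y_imp[~observed] = X1[~observed] @ beta + rng.normal(0.0, sigma, n_missing)
    return y_imp


rng = np.random.default_rng(0)
n = 5_000
X = rng.normal(size=(n, 2))
y_true = 1.0 + 2.0 * X[:, 0] - 1.0 * X[:, 1] + rng.normal(0.0, 1.0, n)
y = y_true.copy()
y[rng.random(n) < 0.25] = np.nan  # 25% missing completely at random

y_imp = stochastic_regression_impute(y, X, rng)
print(f"Q(y) = mean: true {y_true.mean():.3f}, imputed {y_imp.mean():.3f}")
```

This draws a single set of imputes and treats the estimated coefficients as known, so on its own it does not account for all of the uncertainty discussed in the choices that follow.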

When actually doing the imputation you have several choices:

  1. How will you estimate \(f(y|X)\)? - What functional form will you use and how will you estimate the parameters? A regression, a hot deck, machine learning? Embedded in this choice is what you include in \(X\).
  2. How will you draw values of \(y\) given your estimate of \(f(y|X)\)? - Draw from the error distribution of your regression (\(\hat{e}\)), use the empirical distribution of observed values with the same or similar expected values (\(\hat{y}\)) as in predictive mean matching or a hot deck, etc.
  3. How will you account for the relationship between imputed variables in your model? Suppose you have several variables with missing values, such that you want to estimate \(f(y_1|X,y_2)\) and \(f(y_2|X,y_1)\). How will you account for the relationship between \(y_1\) and \(y_2\)? This is what Sequential Regression Multivariate Imputation (SRMI) is for. You run the estimation iteratively. In the first iteration \(^{(1)}\), you estimate \(f(y_1^{(1)}|X)\) (ignoring the relationship between \(y_1\) and \(y_2\)), then estimate \(f(y_2^{(1)}|X,y_1^{(1)})\) (not ignoring the relationship between \(y_2\) and \(y_1\)). In the second iteration \(^{(2)}\), you estimate \(f(y_1^{(2)}|X,y_2^{(1)})\), using the imputed \(y_2\) from the prior iteration to incorporate the relationship between the two \(y\) variables. You then re-impute \(y_2^{(2)}\) with the newly imputed values of \(y_1^{(2)}\). You continue with additional iterations (up to \(k\), for example) until the conditional distributions have converged: \(f(y_1^{(k)}|X,y_2^{(k-1)}) \approx f(y_1|X,y_2)\) and \(f(y_2^{(k)}|X,y_1^{(k)}) \approx f(y_2|X,y_1)\). Many imputations (such as most done at the U.S. Census Bureau) do not do this and stop after the first iteration, in which \(f(y_1^{(1)}|X)\) was estimated without conditioning on \(y_2\) (or on other variables with missing information).
  4. How do you account for uncertainty? You should use multiple imputation (Rubin, 1987), which means that you take some number of independent draws of \(y\) for each missing observation. This allows you to account for the uncertainty in \(y\) given \(f(y|X)\). However, most estimation methods involve estimating \(f(y|X,\theta)\), where \(\theta\) could be the regression coefficients from an OLS regression. Since you are estimating \(\hat{\theta}\), the estimates have uncertainty as well, which should be incorporated into the imputation. This package handles this by taking a Bayesian bootstrap of the sample at each step of the estimation.
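The sketch below strings items 2 through 4 together for two variables under simplifying assumptions: both \(y_1\) and \(y_2\) are continuous, \(f\) is a linear regression, draws use predictive mean matching with \(k=5\) donors, and parameter uncertainty comes from refitting the regression with Bayesian bootstrap (Dirichlet) weights at every step. It is an illustration of the technique, not survey_kit's code; `pmm_draw`, `srmi_two_vars`, and the simulated data are hypothetical.

```python
import numpy as np


def bayesian_bootstrap_weights(n, rng):
    # Bayesian bootstrap: Dirichlet(1, ..., 1) weights over the n donor rows
    return rng.dirichlet(np.ones(n))


def pmm_draw(y, X, rng, k=5):
    """Impute NaNs in y by predictive mean matching on a weighted OLS fit."""
    observed = ~np.isnan(y)
    Xo, yo = X[observed], y[observed]
    # Item 4: fit with Bayesian-bootstrap weights so beta varies across draws
    sw = np.sqrt(bayesian_bootstrap_weights(len(yo), rng))
    beta, *_ = np.linalg.lstsq(Xo * sw[:, None], yo * sw, rcond=None)
    yhat_obs, yhat_mis = Xo @ beta, X[~observed] @ beta
    # Item 2: each missing row receives the observed y of one of the
    # k observed rows with the closest predicted values
    dist = np.abs(yhat_mis[:, None] - yhat_obs[None, :])
    nearest = np.argpartition(dist, k, axis=1)[:, :k]
    pick = nearest[np.arange(len(yhat_mis)), rng.integers(0, k, len(yhat_mis))]
    out = y.copy()
    out[~observed] = yo[pick]
    return out


def srmi_two_vars(y1, y2, X, n_iterations, rng):
    """Item 3: chained (sequential) imputation of y1 and y2 given full X."""
    X1 = np.column_stack([np.ones(len(X)), X])
    cur2 = None
    for _ in range(n_iterations):
        if cur2 is None:
            # First iteration: y2 still has holes, so y1 ignores it
            cur1 = pmm_draw(y1, X1, rng)
        else:
            # Later iterations: condition y1 on the previous iteration's y2
            cur1 = pmm_draw(y1, np.column_stack([X1, cur2]), rng)
        # y2 always conditions on the y1 imputed in this iteration
        cur2 = pmm_draw(y2, np.column_stack([X1, cur1]), rng)
    return cur1, cur2


rng = np.random.default_rng(42)
n = 2_000
X = rng.normal(size=(n, 2))
y1 = X[:, 0] + rng.normal(size=n)
y2 = 0.8 * y1 + X[:, 1] + rng.normal(size=n)
y1[rng.random(n) < 0.25] = np.nan
y2[rng.random(n) < 0.25] = np.nan

# Item 4: multiple imputation as independent repetitions (two implicates)
implicates = [srmi_two_vars(y1, y2, X, n_iterations=5, rng=rng) for _ in range(2)]
```

Setting `n_iterations=1` reproduces the first-iteration-only approach described in item 3, where \(y_1\) is never conditioned on \(y_2\).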

The tutorials below show how you can make some of these choices. For example, by choosing the imputation model type and the method for drawing imputed values, you are choosing \(f\). By specifying the imputation formula (either as an R-style formula or a list of variable names), you are specifying \(X\). By setting the number of iterations (n_iterations), you choose whether or not to do SRMI. By setting the number of implicates (n_implicates), you choose whether to do single or multiple imputation.
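Multiple implicates only help if the per-implicate results are combined. A minimal sketch of Rubin's (1987) combining rules for a scalar statistic follows; `rubin_combine` is a hypothetical helper, not part of this package, and it assumes you already have each implicate's estimate of \(Q\) and its sampling variance.

```python
import numpy as np


def rubin_combine(estimates, variances):
    """Combine one estimate and its sampling variance per implicate."""
    q = np.asarray(estimates, dtype=float)
    u = np.asarray(variances, dtype=float)
    m = len(q)
    q_bar = q.mean()               # combined point estimate
    w = u.mean()                   # within-imputation variance
    b = q.var(ddof=1)              # between-imputation variance
    t = w + (1.0 + 1.0 / m) * b    # total variance
    return q_bar, t


q_bar, t = rubin_combine([1.0, 2.0, 3.0], [0.5, 0.5, 0.5])
print(q_bar, round(t, 4))  # 2.0 1.8333
```

With a single implicate the between-imputation variance `b` cannot be estimated (NumPy returns NaN for `ddof=1` with one value), which is one way to see why single imputation understates uncertainty.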

Why Use It

This package was built to handle production-scale multiple imputation for complex survey data. It provides capabilities that go beyond typical imputation packages:

  • Speed and Scale - Built on polars to handle large datasets efficiently. Impute datasets with hundreds of thousands or millions of observations on a standard laptop in minutes rather than hours. You can pass data in any form supported by Narwhals and you'll get your data back in that form (Pandas, Polars, DuckDB, PyArrow, etc.). We just use polars under the hood for speed.

  • Production-Grade Reliability - Designed for long-running imputation pipelines with managed state and checkpointing. If your imputation is interrupted (power outage, server maintenance, hitting time limits, you want to use your laptop to play a game), you can resume exactly where you left off without losing work.

  • Flexible Method Mixing - Different variables need different approaches. Impute demographic variables via hot deck, income with gradient boosting, and binary indicators with PMM, all in one coherent framework with proper iteration.

  • Complex Dependency Handling - Real survey data has intricate relationships. For example, you may want to:

    • Impute an earnings flag first, then earnings source conditional on that flag (primary earnings from wage and salary, self-employment, or farm self-employment)
    • Impute by subgroups (wage earners, self-employed, farm self-employed) based on previously imputed variables
    • Impute spouse 1's income using spouse 2's, then vice versa, then recalculate household totals to use to impute other variables (like interest income, pensions, etc.)

    This package handles this naturally without requiring you to reshape your data or run totally separate imputation models (which would make it impossible to use SRMI properly).

  • Arbitrary Pre/Post Processing - Execute custom functions at any point in the imputation sequence. Recalculate derived variables, apply business rules, update relationships, or do whatever else your data requires. Keep your data in its natural structure (person-level, household-level) and use the flexibility of the package to make that complexity easy to manage (it's easy to write a function that updates spousal earnings and just call it when you need it).

  • Modern Methods - Native support for gradient boosting (LightGBM) with hyperparameter tuning, quantile regression for continuous variables, etc. These methods are typically unavailable in other imputation packages.

Other tools can be integrated into this package, but the default is LightGBM for its speed and accuracy (other options include XGBoost, CatBoost, and random forests).

Tested at Scale

This package integrates knowledge gained from years of research at the U.S. Census Bureau:

It was designed to handle imputing hundreds of variables across multiple iterations (SRMI) with samples ranging from hundreds of thousands to hundreds of millions of records, with complex dependency structures between variables.

When to Use This vs. Other Packages

Use this package when you:

  • Work with large datasets (>100K rows) where performance matters
  • Need production reliability with checkpoint/resume capability
  • Want to mix different imputation methods intelligently
  • Have complex variable dependencies that require custom logic
  • Need modern ML approaches (gradient boosting, quantile regression)
  • Require hot deck or statistical matching methods

Use mice or similar packages when you:

  • Work with smaller datasets (<100K rows)
  • Want extensive built-in convergence diagnostics
  • Just want something that works and you can cite
  • Want a simpler API with simpler defaults

Note: This package assumes familiarity with imputation methodology. It provides powerful, flexible tools for implementing complex imputation strategies correctly at scale. If you need a point-and-click solution with extensive guardrails, traditional packages may be more appropriate.

API

See the full Imputation/SRMI API documentation

Examples/Tutorials

Stat match and hot deck imputation: the stat match uses a join and the hot deck fills forward from an array across the file, but there is no real theoretical difference between them.
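To see why the two are equivalent in theory, the numpy sketch below implements both for one categorical cell variable (standing in for the match variables in `model_list`): a sequential hot deck that orders the file randomly within cells and carries the last observed donor forward, and a stat match that assigns each recipient a random donor from its cell, with array indexing standing in for the join. Both functions are hypothetical, not survey_kit's implementation, and the cold-start rule (a random donor from the cell) is an assumption.

```python
import numpy as np


def hot_deck_fill_forward(cell, y, rng):
    """Sequential hot deck: random order within cell, carry last donor forward."""
    order = np.lexsort((rng.random(len(y)), cell))  # sort by cell, random ties
    out = y.copy()
    current_cell, last_value = None, np.nan
    for i in order:
        if cell[i] != current_cell:
            current_cell = cell[i]
            pool = y[(cell == current_cell) & ~np.isnan(y)]
            # Cold-start value for the cell: a random donor (NaN if none exist)
            last_value = rng.choice(pool) if len(pool) else np.nan
        if np.isnan(y[i]):
            out[i] = last_value
        else:
            last_value = y[i]
    return out


def stat_match_join(cell, y, rng):
    """Stat match: draw a random donor rank per recipient within its cell."""
    out = y.copy()
    for c in np.unique(cell):
        donors = y[(cell == c) & ~np.isnan(y)]
        recipients = np.flatnonzero((cell == c) & np.isnan(y))
        if len(donors) and len(recipients):
            # Indexing by a random rank stands in for a join on (cell, rank)
            out[recipients] = donors[rng.integers(0, len(donors), len(recipients))]
    return out


rng = np.random.default_rng(1)
n = 2_000
cell = rng.integers(0, 4, n)  # e.g. one id per combination of match variables
y = 10.0 * cell + rng.normal(size=n)
y[rng.random(n) < 0.25] = np.nan

y_hd = hot_deck_fill_forward(cell, y, rng)
y_sm = stat_match_join(cell, y, rng)
```

In both cases every imputed value is an observed value from a donor in the same cell; the two differ only in the mechanics of choosing that donor.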

import sys
import os
from pathlib import Path

import narwhals as nw
import polars as pl
import polars.selectors as cs

from survey_kit.utilities.random import RandomData
from survey_kit.utilities.dataframe import summary

from survey_kit.imputation.variable import Variable
from survey_kit.imputation.parameters import Parameters
from survey_kit.imputation.srmi import SRMI
from survey_kit.orchestration.config import Config

from survey_kit import logger, config
from survey_kit.utilities.dataframe import summary, columns_from_list



# %%
# Draw some random data

n_rows = 10_000
impute_share = 0.25

df = (
    RandomData(n_rows=n_rows, seed=32565437)
    .index("index")
    .integer("year", 2016, 2020)
    .integer("month", 1, 12)
    .integer("var2", 0, 10)
    .integer("var3", 0, 50)
    .float("var4", 0, 1)
    .integer("var5", 0, 1)
    .np_distribution("epsilon_hd1", "normal", scale=5)
    .np_distribution("epsilon_hd2", "normal", scale=5)
    .float("missing_hd1", 0, 1)
    .float("missing_hd2", 0, 1)
    .to_df()
)


#   Convenience references to them for creating dependent variables
c_var2 = pl.col("var2")
c_var3 = pl.col("var3")
c_var4 = pl.col("var4")
c_var5 = pl.col("var5")

c_e_hd1 = pl.col("epsilon_hd1")
c_e_hd2 = pl.col("epsilon_hd2")


logger.info("var_hd1 is binary and conditional on other variables")
c_hd1 = ((c_var2 * 2 - c_var3 * 3 * c_var5 + c_e_hd1)  > 0).alias("var_hd1")

logger.info("var_hd2 is != 0 only if var_hd1 == True")
c_hd2 = (
    pl.when(pl.col("var_hd1"))
      .then(((c_var2 * 1.5 - c_var3 * 1 * c_var4 + c_e_hd2)))
      .otherwise(pl.lit(0))
      .alias("var_hd2")
)
#   Create a bunch of variables that are functions of the variables created above
df = (
    df.with_columns(c_hd1)
    .with_columns(c_hd2)
    .drop(columns_from_list(df=df, columns="epsilon*"))
    .with_row_index(name="_row_index_")
)
df_original = df

#   Set variables to missing according to the uniform random variables missing_
clear_missing = []
for prefixi in ["hd"]:
    for i in range(1, 3):
        vari = f"var_{prefixi}{i}"
        missingi = f"missing_{prefixi}{i}"

        clear_missing.append(
            pl.when(pl.col(missingi) < impute_share)
            .then(pl.lit(None))
            .otherwise(pl.col(vari))
            .alias(vari)
        )
df = df.with_columns(clear_missing).drop(cs.starts_with("missing_"))

summary(df)


#   Actually do the imputation

#       The list of variables to impute (eventually)
vars_impute = []


#   1) Impute some variables using stat match/hot deck
modeltype = Variable.ModelType.StatMatch
modeltype_binary = Variable.ModelType.HotDeck

#       Hot deck parameters (for illustration; v_hd1 and v_hd2
#           below pass their own Parameters.HotDeck)
#           Each model has a set of possible parameters
#           that determine what happens in the model
parameters_hd1 = Parameters.HotDeck(
    #   model_list - a list of variables to match
    #       donors and recipients
    model_list=["var2", "var3", "var5"],
    #   Donate anything other than the variable
    #       (i.e. donate together)
    #       In this case, it's redundant and does nothing...
    donate_list=["var_hd1"],
)

# %%
# Set up the variable to be imputed

logger.info("Impute the boolean variable (var_hd1)")
logger.info("   by setting the model type (a stat match)")
logger.info("   and the list of match variables")
v_hd1 = Variable(
    impute_var="var_hd1",
    modeltype=Variable.ModelType.StatMatch,
    parameters=Parameters.HotDeck(
        model_list=["var2", "var3", "var5"]
    )
)

logger.info("Add the variable to the list to be imputed")
vars_impute.append(v_hd1)


logger.info("Impute the continuous variable (var_hd2) ")
logger.info("   conditional on var_hd1, using narwhals (nw.col('var_hd1'))")
logger.info("   by setting the model type (a hot deck)")
logger.info("   and the list of match variables")
logger.info("   as well as a post-processing edit to set var_hd2=0 when var_hd1==0")

v_hd2 = Variable(
    impute_var="var_hd2",
    Where=nw.col("var_hd1"),
    By=["year", "month"],
    modeltype=Variable.ModelType.HotDeck,
    parameters=Parameters.HotDeck(
        model_list=["var2", "var3", "var5"]
    ),
    postFunctions=(
        nw.when(nw.col("var_hd1"))
          .then(nw.col("var_hd2"))
          .otherwise(nw.lit(0))
          .alias("var_hd2")
    )
)
vars_impute.append(v_hd2)


# %%
logger.info("Set up the imputation")
srmi = SRMI(
    df=df,
    variables=vars_impute,
    n_implicates=2,
    n_iterations=1,
    parallel=False,
    bayesian_bootstrap=True,
    parallel_testing=False,
    path_model=f"{config.path_temp_files}/py_srmi_test_hd",
    force_start=True,
)

# %%
logger.info("Run it")
srmi.run()

# %%
logger.info("Get the results")
_ = df_list = srmi.df_implicates

# %%
logger.info("\n\nLook at the original")
_ = summary(df_original)

logger.info("\n\nLook at the imputes")
_ = df_list.pipe(summary)

logger.info("\n\nLook at the imputes | var_hd1 == 0")
_ = df_list.filter(~nw.col("var_hd1")).pipe(summary)

logger.info("\n\nLook at the imputes | var_hd1 == 1")
_ = df_list.filter(nw.col("var_hd1")).pipe(summary)
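The v_hd2 setup combines three ideas: restrict the imputation to rows where var_hd1 is True (`Where`), run it separately within year-month cells (`By`), and then apply a post-processing edit. The numpy sketch below mimics that logic, assuming `By` means "impute within each group"; it is a hypothetical illustration (`impute_where_by` is not part of survey_kit) and picks a random donor within the cell rather than matching on `model_list`.

```python
import numpy as np


def impute_where_by(flag, y, cells, rng):
    """Impute y only where flag is True, within each By cell, then post-edit."""
    out = y.copy()
    for c in np.unique(cells):
        subset = (cells == c) & flag  # "Where" rows inside this "By" cell
        donors = y[subset & ~np.isnan(y)]
        recipients = np.flatnonzero(subset & np.isnan(y))
        if len(donors) and len(recipients):
            out[recipients] = rng.choice(donors, size=len(recipients))
    # Post-processing edit: y is 0 whenever the flag is False
    return np.where(flag, out, 0.0)


rng = np.random.default_rng(3)
n = 2_000
year = rng.integers(2016, 2021, n)
month = rng.integers(1, 13, n)
cells = year * 100 + month  # one id per (year, month) cell
flag = rng.random(n) < 0.6
y = np.where(flag, rng.normal(5.0, 2.0, n), 0.0)
y[rng.random(n) < 0.25] = np.nan

y_imp = impute_where_by(flag, y, cells, rng)
```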

Logit and/or OLS-based imputation

import sys
import os
from pathlib import Path

import narwhals as nw
import polars as pl
import polars.selectors as cs

from survey_kit.utilities.random import RandomData
from survey_kit.utilities.dataframe import summary

from survey_kit.imputation.variable import Variable
from survey_kit.imputation.parameters import Parameters
from survey_kit.imputation.srmi import SRMI
from survey_kit.imputation.selection import Selection

from survey_kit import logger, config
from survey_kit.utilities.dataframe import summary, columns_from_list
from survey_kit.utilities.formula_builder import FormulaBuilder


path = Path(config.code_root)
sys.path.append(os.path.normpath(path.parent.parent / "tests"))
from scratch import path_scratch


config.data_root = path_scratch(temp_file_suffix=False)


# %%
# Draw some random data

n_rows = 10_000
impute_share = 0.25


df = (
    RandomData(n_rows=n_rows, seed=32565437)
    .index("index")
    .integer("year", 2016, 2020)
    .integer("month", 1, 12)
    .integer("var2", 0, 10)
    .integer("var3", 0, 50)
    .float("var4", 0, 1)
    .integer("var5", 0, 1)
    .float("unrelated_1", 0, 1)
    .float("unrelated_2", 0, 1)
    .float("unrelated_3", 0, 1)
    .float("unrelated_4", 0, 1)
    .float("unrelated_5", 0, 1)
    .np_distribution("epsilon_reg1", "normal", scale=5)
    .np_distribution("epsilon_reg2", "normal", scale=5)
    .float("missing_reg1", 0, 1)
    .float("missing_reg2", 0, 1)
    .to_df()
)


#   Convenience references to them for creating dependent variables
c_var2 = pl.col("var2")
c_var3 = pl.col("var3")
c_var4 = pl.col("var4")
c_var5 = pl.col("var5")

c_e_reg1 = pl.col("epsilon_reg1")
c_e_reg2 = pl.col("epsilon_reg2")



logger.info("var_reg1 is binary and conditional on other variables")
c_reg1 = ((c_var2 * 2 - c_var3 * 3 * c_var5 + c_e_reg1)  > 0).alias("var_reg1")

logger.info("var_reg2 is != 0 only if var_reg1 == True")
c_reg2 = (
    pl.when(pl.col("var_reg1"))
      .then(((c_var2 * 1.5 - c_var3 * 1 * c_var4 + c_e_reg2)))
      .otherwise(pl.lit(0))
      .alias("var_reg2")
)
#   Create a bunch of variables that are functions of the variables created above
df = (
    df.with_columns(c_reg1)
    .with_columns(c_reg2)
    .drop(columns_from_list(df=df, columns="epsilon*"))
    .with_row_index(name="_row_index_")
)

df_original = df

#   Set variables to missing according to the uniform random variables missing_
clear_missing = []
for prefixi in ["reg"]:
    for i in range(1, 3):
        vari = f"var_{prefixi}{i}"
        missingi = f"missing_{prefixi}{i}"

        clear_missing.append(
            pl.when(pl.col(missingi) < impute_share)
            .then(pl.lit(None))
            .otherwise(pl.col(vari))
            .alias(vari)
        )
df = df.with_columns(clear_missing).drop(cs.starts_with("missing_"))

#   Make a fully collinear var for testing
df = df.with_columns(pl.col("unrelated_1").alias("repeat_1"))


summary(df)


#   Actually do the imputation


# %%
logger.info("Define the regression model (intentionally include some extraneous variables)")

f_model = FormulaBuilder(df=df)
f_model.formula_with_varnames_in_brackets(
    "~1+{var_*}+var2+var4+var4*var3*C(var5)+{unrelated_*}+{repeat_*}"
)
logger.info(f_model.formula)


# %%
# Set up the variable to be imputed
vars_impute = []

logger.info("Impute the boolean variable (var_reg1)")
logger.info("   using the default setup for predicted mean matching")
logger.info("   using logit regression")
v_reg1 = Variable(
    impute_var="var_reg1",
    modeltype=Variable.ModelType.pmm,
    model=f_model.formula,
    parameters=Parameters.Regression(model=Parameters.RegressionModel.Logit)
)
logger.info("Add the variable to the list to be imputed")
vars_impute.append(v_reg1)

logger.info("Impute the continuous variable (var_reg2) ")
logger.info("   conditional on var_reg1, using narwhals (nw.col('var_reg1'))")
logger.info("   by setting the model type")
logger.info("   and the formula")
logger.info("   as well as a post-processing edit to set var_reg2=0 when var_reg1==0")
v_reg2 = Variable(
    impute_var="var_reg2",
    Where=nw.col("var_reg1"),
    modeltype=Variable.ModelType.pmm,
    model=f_model.formula,
    #   Default parameters
    parameters=Parameters.Regression(),
    postFunctions=(
        nw.when(nw.col("var_reg1"))
          .then(nw.col("var_reg2"))
          .otherwise(nw.lit(0))
          .alias("var_reg2")
    )
)

vars_impute.append(v_reg2)


# %%
logger.info("Set up the imputation")
logger.info("Add LASSO selection before each imputation")
srmi = SRMI(
    df=df,
    variables=vars_impute,
    n_implicates=2,
    n_iterations=2,
    parallel=False,
    selection=Selection(method=Selection.Method.LASSO),
    modeltype=Variable.ModelType.pmm,
    model=f_model.formula,
    bayesian_bootstrap=True,
    path_model=f"{config.path_temp_files}/py_srmi_test_regression",
    force_start=True,
)

# %%
logger.info("Run it")
srmi.run()


# %%
logger.info("Get the results")
_ = df_list = srmi.df_implicates

# %%
logger.info("\n\nLook at the original")
_ = summary(df_original)

logger.info("\n\nLook at the imputes")
_ = df_list.pipe(summary)

logger.info("\n\nLook at the imputes | var_reg1 == 0")
_ = df_list.filter(~nw.col("var_reg1")).pipe(summary)

logger.info("\n\nLook at the imputes | var_reg1 == 1")
_ = df_list.filter(nw.col("var_reg1")).pipe(summary)
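The var_reg1 step uses PMM with a logit model. The numpy sketch below shows the general idea for a binary variable: fit a logit on the observed rows, then give each missing row the observed 0/1 value of one of its k nearest neighbors on the fitted linear predictor. It is a simplified illustration, not survey_kit's code: it omits the Bayesian bootstrap and LASSO selection used in the tutorial, and `fit_logit`, `pmm_logit`, and `k=5` are assumptions.

```python
import numpy as np


def fit_logit(X, y, n_iter=25):
    """Unpenalized logistic regression fit by Newton-Raphson."""
    beta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(X @ beta)))
        grad = X.T @ (y - p)
        hess = X.T @ (X * (p * (1.0 - p))[:, None])
        beta = beta + np.linalg.solve(hess, grad)
    return beta


def pmm_logit(y, X, rng, k=5):
    """Impute a 0/1 variable by matching on the logit linear predictor."""
    observed = ~np.isnan(y)
    X1 = np.column_stack([np.ones(len(y)), X])
    beta = fit_logit(X1[observed], y[observed])
    score = X1 @ beta
    s_obs, y_obs = score[observed], y[observed]
    out = y.copy()
    for i in np.flatnonzero(~observed):
        # The k observed rows with the closest scores are candidate donors
        nearest = np.argsort(np.abs(s_obs - score[i]))[:k]
        out[i] = y_obs[rng.choice(nearest)]
    return out


rng = np.random.default_rng(11)
n = 4_000
X = rng.normal(size=(n, 2))
p_true = 1.0 / (1.0 + np.exp(-(0.5 + X[:, 0] - X[:, 1])))
y = (rng.random(n) < p_true).astype(float)
y[rng.random(n) < 0.25] = np.nan

y_imp = pmm_logit(y, X, rng)
```

Because every impute is copied from an observed donor, the imputed values are always 0 or 1, which is one reason PMM is a natural fit for binary variables.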

Imputation with LightGBM. See the LightGBM documentation for additional information on some of the options.

import sys
import os
from pathlib import Path

import narwhals as nw
import polars as pl
import polars.selectors as cs

from survey_kit.utilities.random import RandomData
from survey_kit.utilities.dataframe import summary

from survey_kit.imputation.variable import Variable
from survey_kit.imputation.parameters import Parameters
from survey_kit.imputation.srmi import SRMI
from survey_kit.imputation.selection import Selection
import survey_kit.imputation.utilities.lightgbm_wrapper as rep_lgbm
from survey_kit.imputation.utilities.lightgbm_wrapper import Tuner_optuna

from survey_kit import logger, config
from survey_kit.utilities.dataframe import summary, columns_from_list


# %%
# Draw some random data

n_rows = 10_000
impute_share = 0.25


df = (
    RandomData(n_rows=n_rows, seed=32565437)
    .index("index")
    .integer("year", 2016, 2020)
    .integer("month", 1, 12)
    .integer("var2", 0, 10)
    .integer("var3", 0, 50)
    .float("var4", 0, 1)
    .integer("var5", 0, 1)
    .float("unrelated_1", 0, 1)
    .float("unrelated_2", 0, 1)
    .float("unrelated_3", 0, 1)
    .float("unrelated_4", 0, 1)
    .float("unrelated_5", 0, 1)
    .np_distribution("epsilon_gbm1", "normal", scale=5)
    .np_distribution("epsilon_gbm2", "normal", scale=5)
    .np_distribution("epsilon_gbm3", "normal", scale=5)
    .float("missing_gbm1", 0, 1)
    .float("missing_gbm2", 0, 1)
    .float("missing_gbm3", 0, 1)
    .to_df()
)


#   Convenience references to them for creating dependent variables
c_var2 = pl.col("var2")
c_var3 = pl.col("var3")
c_var4 = pl.col("var4")
c_var5 = pl.col("var5")

c_e_gbm1 = pl.col("epsilon_gbm1")
c_e_gbm2 = pl.col("epsilon_gbm2")



logger.info("var_gbm1 is binary and conditional on other variables")
c_gbm1 = ((c_var2 * 2 - c_var3 * 3 * c_var5 + c_e_gbm1)  > 0).alias("var_gbm1")

logger.info("var_gbm2 is != 0 only if var_gbm1 == True")
c_gbm2 = (
    pl.when(pl.col("var_gbm1"))
      .then(((c_var2 * 1.5 - c_var3 * 1 * c_var4 + c_e_gbm2)))
      .otherwise(pl.lit(0))
      .alias("var_gbm2")
)

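logger.info("var_gbm3 is != 0 only if var_gbm1 == True (with its own error term)")
c_e_gbm3 = pl.col("epsilon_gbm3")
c_gbm3 = (
    pl.when(pl.col("var_gbm1"))
      .then(((c_var2 * 1.5 - c_var3 * 1 * c_var4 + c_e_gbm3)))
      .otherwise(pl.lit(0))
      .alias("var_gbm3")
)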
#   Create a bunch of variables that are functions of the variables created above
df = (
    df.with_columns(c_gbm1)
    .with_columns(c_gbm2, c_gbm3)
    .drop(columns_from_list(df=df, columns="epsilon*"))
    .with_row_index(name="_row_index_")
)
df_original = df

#   Set variables to missing according to the uniform random variables missing_
clear_missing = []
for prefixi in ["gbm"]:
    for i in range(1, 4):
        vari = f"var_{prefixi}{i}"
        missingi = f"missing_{prefixi}{i}"

        clear_missing.append(
            pl.when(pl.col(missingi) < impute_share)
            .then(pl.lit(None))
            .otherwise(pl.col(vari))
            .alias(vari)
        )
df = df.with_columns(clear_missing).drop(cs.starts_with("missing_"))

#   Make a fully collinear var for testing
df = df.with_columns(pl.col("unrelated_1").alias("repeat_1"))


summary(df)


# %%
logger.info("Define some dummy functions to run after imputing var_gbm2")
#   Test a simple pre-post function
#       These get run in each iteration (in each implicate)
#           before (preFunctions) or after (postFunctions) this variable is imputed
#   Notes for these functions:
#       1) No type hints on imported package types (will throw an error)
#           i.e. no df:pl.DataFrame or -> pl.DataFrame
#       2) Must be completely self-contained (i.e. all imports within the function)
#           This has to do with how it gets saved and loaded in async calls
#       3) Effectively, you have to assume it'll be called
#           in an environment with no imports before it
def square_var(df, var_to_square: str, name: str):
    import narwhals as nw

    return (
        nw.from_native(df)
        .with_columns((nw.col(var_to_square) ** 2).alias(name))
        .to_native()
    )


def recalculate_interaction(df, var1: str, var2: str, name: str):
    import narwhals as nw

    return (
        nw.from_native(df)
        .with_columns((nw.col(var1) * nw.col(var2)).alias(name))
        .to_native()
    )


# %%
logger.info("Set up hyperparameter tuning")
tuner = Tuner_optuna(
    n_trials=50, objective=rep_lgbm.Tuner.Objectives.mae, test_size=0.25
)

logger.info("   Set the tuner parameters to the defaults")
tuner.parameters()

logger.info("   Then specify the ranges to search over, as follows")
tuner.hyperparameters["num_leaves"] = [2, 256]
tuner.hyperparameters["max_depth"] = [2, 256]
tuner.hyperparameters["min_data_in_leaf"] = [10, 250]
tuner.hyperparameters["num_iterations"] = [25, 200]
tuner.hyperparameters["bagging_fraction"] = [0.5, 1]
tuner.hyperparameters["bagging_freq"] = [1, 5]




vars_impute = []

# %%
logger.info("Impute the boolean variable (var_gbm1)")
logger.info("   using the default setup for predicted mean matching")
logger.info("   using lightgbm")
logger.info("   (you can pass a formula, but you don't need to)")

logger.info("First, set up the lightgbm parameters")
logger.info("   This says, do hyperparameter tuning first (tune)")
logger.info("   Redo it at each run (tune_overwrite)")
logger.info("   And sets the lightgbm parameter defaults (parameters) that the tuning can overwrite")
parameters_lgbm1 = Parameters.LightGBM(
    tune=True,
    tune_hyperparameter_path=f"{config.data_root}/tuner_outputs",
    tuner=tuner,
    tune_overwrite=True,
    parameters={
        "objective": "binary",
        "num_leaves": 32,
        "min_data_in_leaf": 20,
        "num_iterations": 100,
        "test_size": 0.2,
        "boosting": "gbdt",
        "categorical_feature": ["var5"],
        "verbose": -1,
    },
    error=Parameters.ErrorDraw.pmm,
)


logger.info("Actually define the variable and the model")
v_gbm1 = Variable(
    impute_var="var_gbm1",
    model=["var_*", "var4", "var3", "var5", "unrelated_*", "repeat_*"],
    modeltype=Variable.ModelType.LightGBM,
    parameters=parameters_lgbm1
)
logger.info("Add the variable to the list to be imputed")
vars_impute.append(v_gbm1)






logger.info("Impute the continuous variable (var_gbm2) ")
logger.info("   conditional on var_gbm1, using narwhals (nw.col('var_gbm1'))")
logger.info("   as well as a post-processing edit to set var_gbm2=0 when var_gbm1==0")
logger.info("   and some other random post-processing")
logger.info("Different parameters for the continuous variable")
parameters_lgbm2 = Parameters.LightGBM(
    tune=True,
    tune_hyperparameter_path=f"{config.data_root}/tuner_outputs",
    tuner=tuner,
    tune_overwrite=True,
    parameters={
        "objective": "regression",
        "num_leaves": 32,
        "min_data_in_leaf": 20,
        "num_iterations": 100,
        "test_size": 0.2,
        "boosting": "gbdt",
        "categorical_feature": ["var5"],
        "verbose": -1,
    },
    error=Parameters.ErrorDraw.pmm,
)

v_gbm2 = Variable(
    impute_var="var_gbm2",
    Where=nw.col("var_gbm1"),
    #   Needed in case var_gbm1 changes between iterations
    Where_predict=(nw.col("var_gbm2") != 0),
    model=["var_*", "var4", "var3", "var5", "unrelated_*", "repeat_*"],
    modeltype=Variable.ModelType.LightGBM,
    parameters=parameters_lgbm2,
    postFunctions=[
        (
            nw.when(nw.col("var_gbm1"))
            .then(nw.col("var_gbm2"))
            .otherwise(nw.lit(0))
            .alias("var_gbm2")
        ),
        Variable.PrePost.Function(
            recalculate_interaction,
            parameters=dict(
                var1="var_gbm1",
                var2="var_gbm2",
                name="var_gbm12"
            ),
        ),
        Variable.PrePost.Function(
            square_var,
            parameters=dict(
                var_to_square="var_gbm2", 
                name="var_gbm2_sq"
            )
        ),
    ]
)

vars_impute.append(v_gbm2)




logger.info("Now do one with the quantile-regression lightgbm")
logger.info("   To do this, pass quantiles and set objective='quantile'")
parameters_lgbm3 = Parameters.LightGBM(
    tune=True,
    tune_hyperparameter_path=f"{config.data_root}/tuner_outputs",
    tuner=tuner,
    tune_overwrite=True,
    quantiles=[0.25,0.5,0.75],
    parameters={
        "objective": "quantile",
        "num_leaves": 32,
        "min_data_in_leaf": 20,
        "num_iterations": 100,
        "test_size": 0.2,
        "boosting": "gbdt",
        "categorical_feature": ["var5"],
        "verbose": -1,
    },
    error=Parameters.ErrorDraw.pmm,
)

v_gbm3 = Variable(
    impute_var="var_gbm3",
    Where=nw.col("var_gbm1"),
    #   Needed in case var_gbm1 changes between iterations
    Where_predict=(nw.col("var_gbm3") != 0),
    model=["var_*", "var4", "var3", "var5", "unrelated_*", "repeat_*"],
    modeltype=Variable.ModelType.LightGBM,
    parameters=parameters_lgbm3,
    postFunctions=[
        (
            nw.when(nw.col("var_gbm1"))
            .then(nw.col("var_gbm3"))
            .otherwise(nw.lit(0))
            .alias("var_gbm3")
        )
    ]
)

vars_impute.append(v_gbm3)


# %%
logger.info("Set up the imputation")
srmi = SRMI(
    df=df,
    variables=vars_impute,
    n_implicates=2,
    n_iterations=2,
    parallel=False,
    index=["index"],
    modeltype=Variable.ModelType.pmm,
    bayesian_bootstrap=True,
    path_model=f"{config.path_temp_files}/py_srmi_test_gbm",
    force_start=True,
)

# %%
logger.info("Run it")
srmi.run()

logger.info("It's automatically saved and can be loaded with (see path_model above):")
logger.info("path_model = f'{config.path_temp_files}/py_srmi_test_gbm'")
logger.info("srmi = SRMI.load(path_model)")


# %%
logger.info("Get the results")
_ = df_list = srmi.df_implicates

# %%
logger.info("\n\nLook at the original")
_ = summary(df_original,detailed=True,drb_round=True)

logger.info("\n\nLook at the imputes")
_ = df_list.pipe(summary,detailed=True,drb_round=True)

logger.info("\n\nLook at the imputes | var_gbm1 == 0")
_ = df_list.filter(~nw.col("var_gbm1")).pipe(summary,detailed=True,drb_round=True)

logger.info("\n\nLook at the imputes | var_gbm1 == 1")
_ = df_list.filter(nw.col("var_gbm1")).pipe(summary,detailed=True,drb_round=True)
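For the var_gbm3 step, `objective="quantile"` tells LightGBM to minimize the quantile (pinball) loss, whose minimizer is the quantile at the requested level (LightGBM's `alpha` parameter). The numpy sketch below only illustrates that property for an unconditional sample; it makes no claim about how this package turns the predicted quantiles `[0.25, 0.5, 0.75]` into imputed draws, and `pinball_loss` is a hypothetical helper.

```python
import numpy as np


def pinball_loss(y, q, alpha):
    """Average quantile (pinball) loss of predicting q at level alpha."""
    diff = y - q
    return np.mean(np.maximum(alpha * diff, (alpha - 1.0) * diff))


rng = np.random.default_rng(7)
y = rng.normal(loc=10.0, scale=3.0, size=10_001)
grid = np.linspace(0.0, 20.0, 1_001)  # candidate constant predictions, step 0.02

results = {}
for alpha in (0.25, 0.5, 0.75):
    losses = [pinball_loss(y, q, alpha) for q in grid]
    best = grid[int(np.argmin(losses))]
    results[alpha] = (best, np.quantile(y, alpha))
    print(f"alpha={alpha}: argmin {best:.2f}, sample quantile {results[alpha][1]:.2f}")
```

For each level, the constant that minimizes the loss lands next to the sample quantile; a gradient-boosted model does the same thing conditionally on the features.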