Core API Reference

mmm_eval.core

Core validation functionality for MMM frameworks.

Classes

BaseValidationTest(date_column: str)

Bases: ABC

Abstract base class for validation tests.

All validation tests must inherit from this class and implement the required methods to provide a unified testing interface.

Initialize the validation test.

Source code in mmm_eval/core/base_validation_test.py
def __init__(self, date_column: str):
    """Initialize the validation test."""
    self.date_column = date_column
    self.rng = np.random.default_rng(ValidationTestConstants.RANDOM_STATE)
Attributes
test_name: str abstractmethod property

Return the name of the test.

Returns: Test name (e.g., 'accuracy', 'stability')

Functions
run(adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult abstractmethod

Run the validation test.

Parameters:

    adapter (BaseAdapter): The adapter to validate. Required.
    data (DataFrame): Input data for validation. Required.

Returns:

    ValidationTestResult: TestResult object containing test results.

Source code in mmm_eval/core/base_validation_test.py
@abstractmethod
def run(self, adapter: BaseAdapter, data: pd.DataFrame) -> "ValidationTestResult":
    """Run the validation test.

    Args:
        adapter: The adapter to validate
        data: Input data for validation

    Returns:
        TestResult object containing test results

    """
    pass
run_with_error_handling(adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult

Run the validation test with error handling.

Parameters:

    adapter (BaseAdapter): The adapter to validate. Required.
    data (DataFrame): Input data for validation. Required.

Returns:

    ValidationTestResult: TestResult object containing test results.

Raises:

    MetricCalculationError: If metric calculation fails.
    TestExecutionError: If test execution fails.

Source code in mmm_eval/core/base_validation_test.py
def run_with_error_handling(self, adapter: BaseAdapter, data: pd.DataFrame) -> "ValidationTestResult":
    """Run the validation test with error handling.

    Args:
        adapter: The adapter to validate
        data: Input data for validation

    Returns:
        TestResult object containing test results

    Raises:
        MetricCalculationError: If metric calculation fails
        TestExecutionError: If test execution fails

    """
    try:
        return self.run(adapter, data)
    except ZeroDivisionError as e:
        # This is clearly a mathematical calculation issue
        raise MetricCalculationError(f"Metric calculation error in {self.test_name} test: {str(e)}") from e
    except Exception as e:
        # All other errors - let individual tests handle specific categorization if needed
        raise TestExecutionError(f"Test execution error in {self.test_name} test: {str(e)}") from e
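Example (illustrative sketch; `test`, `adapter`, and `data` are assumed to already exist, and the import path for the exceptions is assumed from the module layout shown below):

from mmm_eval.core.exceptions import MetricCalculationError, TestExecutionError  # path assumed

try:
    result = test.run_with_error_handling(adapter, data)
except MetricCalculationError as err:
    # a metric could not be computed, e.g. a division by zero inside the test
    print(f"metric calculation failed: {err}")
except TestExecutionError as err:
    # any other failure during fitting, prediction, or scoring
    print(f"test execution failed: {err}")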

ValidationResults(test_results: dict[ValidationTestNames, ValidationTestResult])

Container for complete validation results.

This class holds the results of all validation tests run, including individual test results and overall summary.

Initialize validation results.

Parameters:

    test_results (dict[ValidationTestNames, ValidationTestResult]): Dictionary mapping test names to their results. Required.
Source code in mmm_eval/core/validation_test_results.py
def __init__(self, test_results: dict[ValidationTestNames, ValidationTestResult]):
    """Initialize validation results.

    Args:
        test_results: Dictionary mapping test names to their results

    """
    self.test_results = test_results
Functions
get_test_result(test_name: ValidationTestNames) -> ValidationTestResult

Get results for a specific test.

Source code in mmm_eval/core/validation_test_results.py
def get_test_result(self, test_name: ValidationTestNames) -> ValidationTestResult:
    """Get results for a specific test."""
    return self.test_results[test_name]
to_df() -> pd.DataFrame

Convert validation results to a flat DataFrame format.

Source code in mmm_eval/core/validation_test_results.py
def to_df(self) -> pd.DataFrame:
    """Convert validation results to a flat DataFrame format."""
    return pd.concat(
        [self.get_test_result(test_name).to_df() for test_name in self.test_results.keys()],
        ignore_index=True,
    )
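Example (illustrative sketch; `results` is assumed to be a ValidationResults instance returned by the orchestrator, and the import path follows the source file location shown above):

from mmm_eval.core.validation_tests_models import ValidationTestNames  # path per source listing

accuracy = results.get_test_result(ValidationTestNames.ACCURACY)
print(accuracy.test_scores)   # metric results for the accuracy test only

df = results.to_df()          # all tests flattened into one DataFrame,
                              # tagged with test_name and timestamp columns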

ValidationTestOrchestrator()

Main orchestrator for running validation tests.

This class manages the test registry and executes tests in sequence, aggregating their results.

Initialize the validator with standard tests pre-registered.

Source code in mmm_eval/core/validation_test_orchestrator.py
def __init__(self):
    """Initialize the validator with standard tests pre-registered."""
    self.tests: dict[ValidationTestNames, type[BaseValidationTest]] = {
        ValidationTestNames.ACCURACY: AccuracyTest,
        ValidationTestNames.CROSS_VALIDATION: CrossValidationTest,
        ValidationTestNames.REFRESH_STABILITY: RefreshStabilityTest,
        ValidationTestNames.PERTURBATION: PerturbationTest,
    }
Functions
validate(adapter: BaseAdapter, data: pd.DataFrame, test_names: list[ValidationTestNames]) -> ValidationResults

Run validation tests on the model.

Parameters:

    adapter (BaseAdapter): Adapter to use for the test. Required.
    data (DataFrame): Input data for validation. Required.
    test_names (list[ValidationTestNames]): List of test names to run. Required.

Returns:

    ValidationResults: ValidationResults containing all test results.

Raises:

    ValueError: If any requested test is not registered.

Source code in mmm_eval/core/validation_test_orchestrator.py
def validate(
    self,
    adapter: BaseAdapter,
    data: pd.DataFrame,
    test_names: list[ValidationTestNames],
) -> ValidationResults:
    """Run validation tests on the model.

    Args:
        adapter: Adapter to use for the test
        data: Input data for validation
        test_names: List of test names to run

    Returns:
        ValidationResults containing all test results

    Raises:
        ValueError: If any requested test is not registered

    """
    # Run tests and collect results
    results: dict[ValidationTestNames, ValidationTestResult] = {}
    for test_name in test_names:
        logger.info(f"Running test: {test_name}")
        test_instance = self.tests[test_name](adapter.date_column)
        test_result = test_instance.run_with_error_handling(adapter, data)
        results[test_name] = test_result

    return ValidationResults(results)
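Example (illustrative sketch; `adapter` is any BaseAdapter implementation and `data` a prepared DataFrame, both assumed to exist; import paths follow the source file locations shown in this reference):

from mmm_eval.core.validation_test_orchestrator import ValidationTestOrchestrator  # path per source listing
from mmm_eval.core.validation_tests_models import ValidationTestNames  # path per source listing

orchestrator = ValidationTestOrchestrator()
results = orchestrator.validate(
    adapter=adapter,
    data=data,
    test_names=[ValidationTestNames.ACCURACY, ValidationTestNames.REFRESH_STABILITY],
)
results_df = results.to_df()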

ValidationTestResult(test_name: ValidationTestNames, metric_names: list[str], test_scores: AccuracyMetricResults | CrossValidationMetricResults | RefreshStabilityMetricResults | PerturbationMetricResults)

Container for individual test results.

This class holds the results of a single validation test, including pass/fail status, metrics, and any error messages.

Initialize test results.

Parameters:

    test_name (ValidationTestNames): Name of the test. Required.
    metric_names (list[str]): List of metric names. Required.
    test_scores (AccuracyMetricResults | CrossValidationMetricResults | RefreshStabilityMetricResults | PerturbationMetricResults): Computed metric results. Required.
Source code in mmm_eval/core/validation_test_results.py
def __init__(
    self,
    test_name: ValidationTestNames,
    metric_names: list[str],
    test_scores: (
        AccuracyMetricResults
        | CrossValidationMetricResults
        | RefreshStabilityMetricResults
        | PerturbationMetricResults
    ),
):
    """Initialize test results.

    Args:
        test_name: Name of the test
        metric_names: List of metric names
        test_scores: Computed metric results

    """
    self.test_name = test_name
    self.metric_names = metric_names
    self.test_scores = test_scores
    self.timestamp = datetime.now()
Functions
to_df() -> pd.DataFrame

Convert test results to a flat DataFrame format.

Source code in mmm_eval/core/validation_test_results.py
def to_df(self) -> pd.DataFrame:
    """Convert test results to a flat DataFrame format."""
    test_scores_df = self.test_scores.to_df()
    test_scores_df[ValidationTestAttributeNames.TEST_NAME.value] = self.test_name.value
    test_scores_df[ValidationTestAttributeNames.TIMESTAMP.value] = self.timestamp.isoformat()
    return test_scores_df

Modules

base_validation_test

Abstract base classes for MMM validation framework.

Classes
BaseValidationTest(date_column: str)

Bases: ABC

Abstract base class for validation tests.

All validation tests must inherit from this class and implement the required methods to provide a unified testing interface.

Initialize the validation test.

Source code in mmm_eval/core/base_validation_test.py
def __init__(self, date_column: str):
    """Initialize the validation test."""
    self.date_column = date_column
    self.rng = np.random.default_rng(ValidationTestConstants.RANDOM_STATE)
Attributes
test_name: str abstractmethod property

Return the name of the test.

Returns: Test name (e.g., 'accuracy', 'stability')

Functions
run(adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult abstractmethod

Run the validation test.

Parameters:

    adapter (BaseAdapter): The adapter to validate. Required.
    data (DataFrame): Input data for validation. Required.

Returns:

    ValidationTestResult: TestResult object containing test results.

Source code in mmm_eval/core/base_validation_test.py
@abstractmethod
def run(self, adapter: BaseAdapter, data: pd.DataFrame) -> "ValidationTestResult":
    """Run the validation test.

    Args:
        adapter: The adapter to validate
        data: Input data for validation

    Returns:
        TestResult object containing test results

    """
    pass
run_with_error_handling(adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult

Run the validation test with error handling.

Parameters:

    adapter (BaseAdapter): The adapter to validate. Required.
    data (DataFrame): Input data for validation. Required.

Returns:

    ValidationTestResult: TestResult object containing test results.

Raises:

    MetricCalculationError: If metric calculation fails.
    TestExecutionError: If test execution fails.

Source code in mmm_eval/core/base_validation_test.py
def run_with_error_handling(self, adapter: BaseAdapter, data: pd.DataFrame) -> "ValidationTestResult":
    """Run the validation test with error handling.

    Args:
        adapter: The adapter to validate
        data: Input data for validation

    Returns:
        TestResult object containing test results

    Raises:
        MetricCalculationError: If metric calculation fails
        TestExecutionError: If test execution fails

    """
    try:
        return self.run(adapter, data)
    except ZeroDivisionError as e:
        # This is clearly a mathematical calculation issue
        raise MetricCalculationError(f"Metric calculation error in {self.test_name} test: {str(e)}") from e
    except Exception as e:
        # All other errors - let individual tests handle specific categorization if needed
        raise TestExecutionError(f"Test execution error in {self.test_name} test: {str(e)}") from e
Functions
split_timeseries_cv(data: pd.DataFrame, n_splits: PositiveInt, test_size: PositiveInt, date_column: str) -> Generator[tuple[np.ndarray, np.ndarray], None, None]

Produce train/test masks for rolling CV, split globally based on date.

This simulates regular refreshes and utilises the last test_size data points for testing in the first fold, using all prior data for training. For a dataset with T dates and test_size=4, the successive test folds cover [T-4, T], [T-8, T-4], ...

Parameters:

    data (DataFrame): Dataframe of MMM data to be split. Required.
    n_splits (PositiveInt): Number of unique folds to generate. Required.
    test_size (PositiveInt): The number of observations in each testing fold. Required.
    date_column (str): The name of the date column in the dataframe to split by. Required.

Yields:

    tuple[ndarray, ndarray]: Boolean masks for the training and test set indices.

Source code in mmm_eval/core/base_validation_test.py
def split_timeseries_cv(
    data: pd.DataFrame, n_splits: PositiveInt, test_size: PositiveInt, date_column: str
) -> Generator[tuple[np.ndarray, np.ndarray], None, None]:
    """Produce train/test masks for rolling CV, split globally based on date.

    This simulates regular refreshes and utilises the last `test_size` data points for
    testing in the first fold, using all prior data for training. For a dataset with
    T dates and test_size=4, the successive test folds cover [T-4, T], [T-8, T-4], ...

    Arguments:
        data: dataframe of MMM data to be split
        n_splits: number of unique folds to generate
        test_size: the number of observations in each testing fold
        date_column: the name of the date column in the dataframe to split by

    Yields:
        boolean masks corresponding to the training and test set indices.

    """
    sorted_dates = sorted(data[date_column].unique())
    n_dates = len(sorted_dates)

    # assuming the minimum training set size allowable is equal to `test_size`, ensure there's
    # enough data temporally to do the splits
    n_required_dates = test_size * (n_splits + 1)
    if n_dates < n_required_dates:
        raise ValueError(
            "Insufficient timeseries data provided for splitting. In order to "
            f"perform {n_splits} splits with test_size={test_size}, at least "
            f"{n_required_dates} unique dates are required, but only {n_dates} "
            f"dates are available."
        )

    for i in range(n_splits):
        test_end = n_dates - i * test_size
        test_start = n_dates - (i + 1) * test_size
        test_dates = sorted_dates[test_start:test_end]
        train_dates = sorted_dates[:test_start]

        train_mask = data[date_column].isin(train_dates)
        test_mask = data[date_column].isin(test_dates)
        yield train_mask, test_mask
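Example (self-contained sketch on a toy weekly dataset; the import path follows the source file location shown above):

import numpy as np
import pandas as pd

from mmm_eval.core.base_validation_test import split_timeseries_cv  # path per source listing

# 20 weekly dates are enough for n_splits=3 with test_size=4,
# since at least test_size * (n_splits + 1) = 16 unique dates are required.
df = pd.DataFrame(
    {
        "date": pd.date_range("2024-01-07", periods=20, freq="W"),
        "response": np.arange(20.0),
    }
)

for fold, (train_mask, test_mask) in enumerate(
    split_timeseries_cv(df, n_splits=3, test_size=4, date_column="date")
):
    train, test = df[train_mask], df[test_mask]
    print(fold, train["date"].max().date(), test["date"].min().date(), len(test))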
split_timeseries_data(data: pd.DataFrame, test_proportion: PositiveFloat, date_column: str) -> tuple[np.ndarray, np.ndarray]

Split data globally based on date.

Parameters:

    data (DataFrame): Timeseries data to split, possibly with another index like geography. Required.
    test_proportion (PositiveFloat): Proportion of test data; must be in (0, 1). Required.
    date_column (str): Name of the date column. Required.

Returns:

    tuple[ndarray, ndarray]: Boolean masks for training and test data respectively.

Source code in mmm_eval/core/base_validation_test.py
def split_timeseries_data(
    data: pd.DataFrame, test_proportion: PositiveFloat, date_column: str
) -> tuple[np.ndarray, np.ndarray]:
    """Split data globally based on date.

    Arguments:
        data: timeseries data to split, possibly with another index like geography
        test_proportion: proportion of test data, must be in (0, 1)
        date_column: name of the date column

    Returns:
        boolean masks for training and test data respectively

    """
    if test_proportion <= 0 or test_proportion >= 1:
        raise ValueError("`test_proportion` must be in the range (0, 1)")

    sorted_dates = sorted(data[date_column].unique())
    # rounding eliminates possibility of floating point precision issues
    split_idx = int(round(len(sorted_dates) * (1 - test_proportion)))
    cutoff = sorted_dates[split_idx]

    train_mask = data[date_column] < cutoff
    test_mask = data[date_column] >= cutoff

    return train_mask, test_mask
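Example (sketch for the simple holdout split, reusing the toy `df` from the cross-validation sketch above):

from mmm_eval.core.base_validation_test import split_timeseries_data  # path per source listing

# The last 20% of unique dates become the test set.
train_mask, test_mask = split_timeseries_data(df, test_proportion=0.2, date_column="date")
train, test = df[train_mask], df[test_mask]
assert train["date"].max() < test["date"].min()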

constants

Classes
ValidationTestConstants

Constants for the validation tests.

Classes
PerturbationConstants

Constants for the perturbation test.

evaluator

Main evaluator for MMM frameworks.

Classes
Evaluator(data: pd.DataFrame, test_names: tuple[str, ...] | None = None)

Main evaluator class for MMM frameworks.

This class provides a unified interface for evaluating different MMM frameworks using standardized validation tests.

Initialize the evaluator.

Source code in mmm_eval/core/evaluator.py
def __init__(self, data: pd.DataFrame, test_names: tuple[str, ...] | None = None):
    """Initialize the evaluator."""
    self.validation_orchestrator = ValidationTestOrchestrator()
    self.data = data
    self.test_names = (
        self._get_test_names(test_names) if test_names else self.validation_orchestrator._get_all_test_names()
    )
Functions
evaluate_framework(framework: str, config: BaseConfig) -> ValidationResults

Evaluate an MMM framework using the unified API.

Parameters:

    framework (str): Name of the MMM framework to evaluate. Required.
    config (BaseConfig): Framework-specific configuration. Required.

Returns:

    ValidationResults: ValidationResult object containing evaluation metrics and predictions.

Raises:

    ValueError: If any test name is invalid.

Source code in mmm_eval/core/evaluator.py
def evaluate_framework(self, framework: str, config: BaseConfig) -> ValidationResults:
    """Evaluate an MMM framework using the unified API.

    Args:
        framework: Name of the MMM framework to evaluate
        config: Framework-specific configuration

    Returns:
        ValidationResult object containing evaluation metrics and predictions

    Raises:
        ValueError: If any test name is invalid

    """
    # Initialize the adapter
    adapter = get_adapter(framework, config)

    # Run validation tests
    validation_results = self.validation_orchestrator.validate(
        adapter=adapter,
        data=self.data,
        test_names=self.test_names,
    )

    return validation_results
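Example (illustrative sketch; `data` is a prepared DataFrame and `config` a framework-specific BaseConfig subclass, both assumed to exist. The framework string and test-name strings are illustrative only and depend on the adapters registered):

from mmm_eval.core.evaluator import Evaluator  # path per source listing

evaluator = Evaluator(data=data, test_names=("accuracy", "cross_validation"))
results = evaluator.evaluate_framework(framework="pymc_marketing", config=config)
print(results.to_df())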
Functions

exceptions

Custom exceptions for MMM validation framework.

Classes
InvalidTestNameError

Bases: ValidationError

Raised when an invalid test name is provided.

MetricCalculationError

Bases: ValidationError

Raised when metric calculation fails.

TestExecutionError

Bases: ValidationError

Raised when test execution fails.

ValidationError

Bases: Exception

Base exception for validation framework errors.

run_evaluation

Functions
run_evaluation(framework: str, data: pd.DataFrame, config: BaseConfig, test_names: tuple[str, ...] | None = None) -> pd.DataFrame

Evaluate an MMM framework.

Parameters:

    framework (str): The framework to evaluate. Required.
    data (DataFrame): The data to evaluate. Required.
    config (BaseConfig): The config to use for the evaluation. Required.
    test_names (tuple[str, ...] | None): The tests to run. If not provided, all tests will be run. Default: None.

Returns:

    DataFrame: A pandas DataFrame containing the evaluation results.

Source code in mmm_eval/core/run_evaluation.py
def run_evaluation(
    framework: str,
    data: pd.DataFrame,
    config: BaseConfig,
    test_names: tuple[str, ...] | None = None,
) -> pd.DataFrame:
    """Evaluate an MMM framework.

    Args:
        framework: The framework to evaluate.
        data: The data to evaluate.
        config: The config to use for the evaluation.
        test_names: The tests to run. If not provided, all tests will be run.

    Returns:
        A pandas DataFrame containing the evaluation results.

    """
    # validate + process the input data
    data = DataPipeline(
        data=data,
        framework=framework,
        date_column=config.date_column,
        response_column=config.response_column,
        revenue_column=config.revenue_column,
        control_columns=config.control_columns,
        channel_columns=config.channel_columns,
    ).run()

    # run the evaluation suite
    results = Evaluator(
        data=data,
        test_names=test_names,
    ).evaluate_framework(framework=framework, config=config)

    return results.to_df()
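Example (end-to-end sketch of the convenience entry point; the top-level import path, the framework string, and the config construction are assumptions, not confirmed by this reference):

import pandas as pd

from mmm_eval import run_evaluation  # import path assumed

# `my_config` is a framework-specific BaseConfig subclass built elsewhere;
# the framework string is illustrative.
data = pd.read_csv("mmm_input.csv")
results_df = run_evaluation(
    framework="pymc_marketing",
    data=data,
    config=my_config,
    test_names=("accuracy", "perturbation"),  # omit to run every registered test
)
results_df.to_csv("mmm_eval_results.csv", index=False)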

validation_test_orchestrator

Test orchestrator for MMM validation framework.

Classes
ValidationTestOrchestrator()

Main orchestrator for running validation tests.

This class manages the test registry and executes tests in sequence, aggregating their results.

Initialize the validator with standard tests pre-registered.

Source code in mmm_eval/core/validation_test_orchestrator.py
def __init__(self):
    """Initialize the validator with standard tests pre-registered."""
    self.tests: dict[ValidationTestNames, type[BaseValidationTest]] = {
        ValidationTestNames.ACCURACY: AccuracyTest,
        ValidationTestNames.CROSS_VALIDATION: CrossValidationTest,
        ValidationTestNames.REFRESH_STABILITY: RefreshStabilityTest,
        ValidationTestNames.PERTURBATION: PerturbationTest,
    }
Functions
validate(adapter: BaseAdapter, data: pd.DataFrame, test_names: list[ValidationTestNames]) -> ValidationResults

Run validation tests on the model.

Parameters:

    adapter (BaseAdapter): Adapter to use for the test. Required.
    data (DataFrame): Input data for validation. Required.
    test_names (list[ValidationTestNames]): List of test names to run. Required.

Returns:

    ValidationResults: ValidationResults containing all test results.

Raises:

    ValueError: If any requested test is not registered.

Source code in mmm_eval/core/validation_test_orchestrator.py
def validate(
    self,
    adapter: BaseAdapter,
    data: pd.DataFrame,
    test_names: list[ValidationTestNames],
) -> ValidationResults:
    """Run validation tests on the model.

    Args:
        adapter: Adapter to use for the test
        data: Input data for validation
        test_names: List of test names to run

    Returns:
        ValidationResults containing all test results

    Raises:
        ValueError: If any requested test is not registered

    """
    # Run tests and collect results
    results: dict[ValidationTestNames, ValidationTestResult] = {}
    for test_name in test_names:
        logger.info(f"Running test: {test_name}")
        test_instance = self.tests[test_name](adapter.date_column)
        test_result = test_instance.run_with_error_handling(adapter, data)
        results[test_name] = test_result

    return ValidationResults(results)

validation_test_results

Result containers for MMM validation framework.

Classes
ValidationResults(test_results: dict[ValidationTestNames, ValidationTestResult])

Container for complete validation results.

This class holds the results of all validation tests run, including individual test results and overall summary.

Initialize validation results.

Parameters:

    test_results (dict[ValidationTestNames, ValidationTestResult]): Dictionary mapping test names to their results. Required.
Source code in mmm_eval/core/validation_test_results.py
def __init__(self, test_results: dict[ValidationTestNames, ValidationTestResult]):
    """Initialize validation results.

    Args:
        test_results: Dictionary mapping test names to their results

    """
    self.test_results = test_results
Functions
get_test_result(test_name: ValidationTestNames) -> ValidationTestResult

Get results for a specific test.

Source code in mmm_eval/core/validation_test_results.py
def get_test_result(self, test_name: ValidationTestNames) -> ValidationTestResult:
    """Get results for a specific test."""
    return self.test_results[test_name]
to_df() -> pd.DataFrame

Convert validation results to a flat DataFrame format.

Source code in mmm_eval/core/validation_test_results.py
def to_df(self) -> pd.DataFrame:
    """Convert validation results to a flat DataFrame format."""
    return pd.concat(
        [self.get_test_result(test_name).to_df() for test_name in self.test_results.keys()],
        ignore_index=True,
    )
ValidationTestResult(test_name: ValidationTestNames, metric_names: list[str], test_scores: AccuracyMetricResults | CrossValidationMetricResults | RefreshStabilityMetricResults | PerturbationMetricResults)

Container for individual test results.

This class holds the results of a single validation test, including pass/fail status, metrics, and any error messages.

Initialize test results.

Parameters:

    test_name (ValidationTestNames): Name of the test. Required.
    metric_names (list[str]): List of metric names. Required.
    test_scores (AccuracyMetricResults | CrossValidationMetricResults | RefreshStabilityMetricResults | PerturbationMetricResults): Computed metric results. Required.
Source code in mmm_eval/core/validation_test_results.py
def __init__(
    self,
    test_name: ValidationTestNames,
    metric_names: list[str],
    test_scores: (
        AccuracyMetricResults
        | CrossValidationMetricResults
        | RefreshStabilityMetricResults
        | PerturbationMetricResults
    ),
):
    """Initialize test results.

    Args:
        test_name: Name of the test
        metric_names: List of metric names
        test_scores: Computed metric results

    """
    self.test_name = test_name
    self.metric_names = metric_names
    self.test_scores = test_scores
    self.timestamp = datetime.now()
Functions
to_df() -> pd.DataFrame

Convert test results to a flat DataFrame format.

Source code in mmm_eval/core/validation_test_results.py
def to_df(self) -> pd.DataFrame:
    """Convert test results to a flat DataFrame format."""
    test_scores_df = self.test_scores.to_df()
    test_scores_df[ValidationTestAttributeNames.TEST_NAME.value] = self.test_name.value
    test_scores_df[ValidationTestAttributeNames.TIMESTAMP.value] = self.timestamp.isoformat()
    return test_scores_df

validation_tests

Classes
AccuracyTest(date_column: str)

Bases: BaseValidationTest

Validation test for model accuracy using holdout validation.

This test evaluates model performance by splitting data into train/test sets and calculating MAPE and R-squared metrics on the test set.

Source code in mmm_eval/core/base_validation_test.py
def __init__(self, date_column: str):
    """Initialize the validation test."""
    self.date_column = date_column
    self.rng = np.random.default_rng(ValidationTestConstants.RANDOM_STATE)
Attributes
test_name: ValidationTestNames property

Return the name of the test.

Functions
run(adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult

Run the accuracy test.

Source code in mmm_eval/core/validation_tests.py
def run(self, adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult:
    """Run the accuracy test."""
    # Split data into train/test sets
    train, test = self._split_data_holdout(data)
    predictions = adapter.fit_and_predict(train, test)
    actual = test.groupby(self.date_column)[InputDataframeConstants.RESPONSE_COL].sum()
    assert len(actual) == len(predictions), "Actual and predicted lengths must match"

    # Calculate metrics
    test_scores = AccuracyMetricResults.populate_object_with_metrics(
        actual=pd.Series(actual),  # Ensure it's a Series
        predicted=pd.Series(predictions, index=actual.index),
    )

    logger.info(f"Saving the test results for {self.test_name} test")

    return ValidationTestResult(
        test_name=ValidationTestNames.ACCURACY,
        metric_names=AccuracyMetricNames.to_list(),
        test_scores=test_scores,
    )
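Example (illustrative sketch of running a single test directly, outside the orchestrator; `adapter` is any BaseAdapter implementation and `data` a prepared DataFrame, both assumed to exist):

from mmm_eval.core.validation_tests import AccuracyTest  # path per source listing

test = AccuracyTest(date_column="date")
result = test.run_with_error_handling(adapter, data)
print(result.to_df())  # MAPE and R-squared for the holdout period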
CrossValidationTest(date_column: str)

Bases: BaseValidationTest

Validation test for the cross-validation of the MMM framework.

Source code in mmm_eval/core/base_validation_test.py
def __init__(self, date_column: str):
    """Initialize the validation test."""
    self.date_column = date_column
    self.rng = np.random.default_rng(ValidationTestConstants.RANDOM_STATE)
Attributes
test_name: ValidationTestNames property

Return the name of the test.

Functions
run(adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult

Run the cross-validation test using time-series splits.

Parameters:

    adapter (BaseAdapter): Adapter to use for the test. Required.
    data (DataFrame): Input data. Required.

Returns:

    ValidationTestResult: TestResult containing cross-validation metrics.

Source code in mmm_eval/core/validation_tests.py
def run(self, adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult:
    """Run the cross-validation test using time-series splits.

    Args:
        adapter: Adapter to use for the test
        data: Input data

    Returns:
        TestResult containing cross-validation metrics

    """
    # Initialize cross-validation splitter
    cv_splits = self._split_data_time_series_cv(data)

    # Store metrics for each fold
    fold_metrics = []

    # Run cross-validation
    for i, (train_idx, test_idx) in enumerate(cv_splits):

        logger.info(f"Running cross-validation fold {i+1} of {len(cv_splits)}")

        # Get train/test data
        train = data.loc[train_idx]
        test = data.loc[test_idx]

        # Get predictions
        predictions = adapter.fit_and_predict(train, test)
        actual = test.groupby(self.date_column)[InputDataframeConstants.RESPONSE_COL].sum()
        assert len(actual) == len(predictions), "Actual and predicted lengths must match"

        # Add in fold results
        fold_metrics.append(
            AccuracyMetricResults.populate_object_with_metrics(
                actual=pd.Series(actual),  # Ensure it's a Series
                predicted=pd.Series(predictions, index=actual.index),
            )
        )

    # Calculate mean and std of metrics across folds and create metric results
    test_scores = CrossValidationMetricResults(
        mean_mape=calculate_mean_for_singular_values_across_cross_validation_folds(
            fold_metrics, AccuracyMetricNames.MAPE
        ),
        std_mape=calculate_std_for_singular_values_across_cross_validation_folds(
            fold_metrics, AccuracyMetricNames.MAPE
        ),
        mean_r_squared=calculate_mean_for_singular_values_across_cross_validation_folds(
            fold_metrics, AccuracyMetricNames.R_SQUARED
        ),
    )

    logger.info(f"Saving the test results for {self.test_name} test")

    return ValidationTestResult(
        test_name=ValidationTestNames.CROSS_VALIDATION,
        metric_names=CrossValidationMetricNames.to_list(),
        test_scores=test_scores,
    )
PerturbationTest(date_column: str)

Bases: BaseValidationTest

Validation test for the perturbation of the MMM framework.

Source code in mmm_eval/core/base_validation_test.py
def __init__(self, date_column: str):
    """Initialize the validation test."""
    self.date_column = date_column
    self.rng = np.random.default_rng(ValidationTestConstants.RANDOM_STATE)
Attributes
test_name: ValidationTestNames property

Return the name of the test.

Functions
run(adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult

Run the perturbation test.

Source code in mmm_eval/core/validation_tests.py
def run(self, adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult:
    """Run the perturbation test."""
    # Train model on original data
    adapter.fit(data)
    original_rois = adapter.get_channel_roi()

    # TODO: support perturbation of reach and frequency regressors
    if adapter.primary_media_regressor_type == PrimaryMediaRegressor.REACH_AND_FREQUENCY:
        logger.warning(
            "Perturbation test skipped: Reach and frequency regressor type not supported for perturbation."
        )
        # Return NaN results for each channel indicating the test was not run
        channel_names = adapter.get_channel_names()
        test_scores = PerturbationMetricResults(
            percentage_change_for_each_channel=pd.Series(np.nan, index=channel_names),
        )
        return ValidationTestResult(
            test_name=ValidationTestNames.PERTURBATION,
            metric_names=PerturbationMetricNames.to_list(),
            test_scores=test_scores,
        )

    # Add noise to primary regressor data and retrain
    noisy_data = self._add_gaussian_noise_to_primary_regressors(
        df=data,
        regressor_cols=adapter.primary_media_regressor_columns,
    )
    adapter.fit(noisy_data)
    noise_rois = adapter.get_channel_roi()

    # calculate the pct change in roi
    percentage_change = calculate_absolute_percentage_change(
        baseline_series=original_rois,
        comparison_series=noise_rois,
    )

    # Create metric results - roi % change for each channel
    test_scores = PerturbationMetricResults(
        percentage_change_for_each_channel=percentage_change,
    )

    logger.info(f"Saving the test results for {self.test_name} test")

    return ValidationTestResult(
        test_name=ValidationTestNames.PERTURBATION,
        metric_names=PerturbationMetricNames.to_list(),
        test_scores=test_scores,
    )
RefreshStabilityTest(date_column: str)

Bases: BaseValidationTest

Validation test for the stability of the MMM framework.

Source code in mmm_eval/core/base_validation_test.py
def __init__(self, date_column: str):
    """Initialize the validation test."""
    self.date_column = date_column
    self.rng = np.random.default_rng(ValidationTestConstants.RANDOM_STATE)
Attributes
test_name: ValidationTestNames property

Return the name of the test.

Functions
run(adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult

Run the stability test.

Source code in mmm_eval/core/validation_tests.py
def run(self, adapter: BaseAdapter, data: pd.DataFrame) -> ValidationTestResult:
    """Run the stability test."""
    # Initialize cross-validation splitter
    cv_splits = self._split_data_time_series_cv(data)

    # Store metrics for each fold
    fold_metrics = []

    # Run cross-validation
    for i, (train_idx, refresh_idx) in enumerate(cv_splits):

        logger.info(f"Running refresh stability test fold {i+1} of {len(cv_splits)}")

        # Get train/test data
        # todo(): Can we somehow store these training changes in the adapter for use in time series holdout test
        current_data = data.loc[train_idx]
        # Combine current data with refresh data for retraining
        refresh_data = pd.concat([current_data, data.loc[refresh_idx]], ignore_index=True)
        # Get common dates for roi stability comparison
        common_start_date, common_end_date = self._get_common_dates(
            baseline_data=current_data,
            comparison_data=refresh_data,
            date_column=adapter.date_column,
        )

        # Train model and get coefficients
        adapter.fit(current_data)
        current_model_rois = adapter.get_channel_roi(
            start_date=common_start_date,
            end_date=common_end_date,
        )
        adapter.fit(refresh_data)
        refreshed_model_rois = adapter.get_channel_roi(
            start_date=common_start_date,
            end_date=common_end_date,
        )

        # calculate the pct change in volume
        percentage_change = calculate_absolute_percentage_change(
            baseline_series=current_model_rois,
            comparison_series=refreshed_model_rois,
        )

        fold_metrics.append(percentage_change)

    # Calculate mean and std of percentage change for each channel across cross validation folds
    test_scores = RefreshStabilityMetricResults(
        mean_percentage_change_for_each_channel=calculate_means_for_series_across_cross_validation_folds(
            fold_metrics
        ),
        std_percentage_change_for_each_channel=calculate_stds_for_series_across_cross_validation_folds(
            fold_metrics
        ),
    )

    logger.info(f"Saving the test results for {self.test_name} test")

    return ValidationTestResult(
        test_name=ValidationTestNames.REFRESH_STABILITY,
        metric_names=RefreshStabilityMetricNames.to_list(),
        test_scores=test_scores,
    )
Functions

validation_tests_models

Classes
ValidationResultAttributeNames

Bases: StrEnum

Define the names of the validation result attributes.

ValidationTestAttributeNames

Bases: StrEnum

Define the names of the validation test attributes.

ValidationTestNames

Bases: StrEnum

Define the names of the validation tests.

Functions
all_tests() -> list[ValidationTestNames] classmethod

Return all validation test names as a list.

Source code in mmm_eval/core/validation_tests_models.py
@classmethod
def all_tests(cls) -> list["ValidationTestNames"]:
    """Return all validation test names as a list."""
    return list(cls)
all_tests_as_str() -> list[str] classmethod

Return all validation test names as a list of strings.

Source code in mmm_eval/core/validation_tests_models.py
@classmethod
def all_tests_as_str(cls) -> list[str]:
    """Return all validation test names as a list of strings."""
    return [test.value for test in cls]
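Example (illustrative sketch; the exact string values depend on the enum definition, which is not shown in this reference):

from mmm_eval.core.validation_tests_models import ValidationTestNames  # path per source listing

for name in ValidationTestNames.all_tests():
    print(name, name.value)

print(ValidationTestNames.all_tests_as_str())  # e.g. ['accuracy', 'cross_validation', ...]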