Source code for baybe.utils.augmentation

"""Utilities related to data augmentation."""

from collections.abc import Collection, Sequence
from itertools import permutations, product

import pandas as pd


def df_apply_permutation_augmentation(
    df: pd.DataFrame,
    column_groups: Sequence[Sequence[str]],
) -> pd.DataFrame:
    """Augment a dataframe if permutation invariant columns are present.

    Args:
        df: The dataframe that should be augmented.
        column_groups: Sequences of permutation invariant columns. The n'th column in
            each group will be permuted together with each n'th column in the other
            groups.

    Returns:
        The augmented dataframe containing the original one. Augmented row indices
        are identical with the index of their original row, and all permutations of
        a row directly follow each other, starting with the unpermuted row. Columns
        that are not part of any group keep their original dtype.

    Raises:
        ValueError: If less than two column groups are given.
        ValueError: If a column group is empty.
        ValueError: If the column groups have differing amounts of entries.

    Examples:
        >>> df = pd.DataFrame({'A1':[1,2],'A2':[3,4], 'B1': [5, 6], 'B2': [7, 8]})
        >>> df
           A1  A2  B1  B2
        0   1   3   5   7
        1   2   4   6   8

        >>> column_groups = [['A1'], ['A2']]
        >>> dfa = df_apply_permutation_augmentation(df, column_groups)
        >>> dfa
           A1  A2  B1  B2
        0   1   3   5   7
        0   3   1   5   7
        1   2   4   6   8
        1   4   2   6   8

        >>> column_groups = [['A1', 'B1'], ['A2', 'B2']]
        >>> dfa = df_apply_permutation_augmentation(df, column_groups)
        >>> dfa
           A1  A2  B1  B2
        0   1   3   5   7
        0   3   1   7   5
        1   2   4   6   8
        1   4   2   8   6
    """
    # Validation
    if len(column_groups) < 2:
        raise ValueError(
            "When augmenting permutation invariance, at least two column sequences "
            "must be given."
        )
    if len({len(seq) for seq in column_groups}) != 1:
        raise ValueError(
            "Permutation augmentation can only work if the amount of columns in each "
            "sequence is the same."
        )
    if len(column_groups[0]) < 1:
        raise ValueError(
            "Permutation augmentation can only work if each column group has at "
            "least one entry."
        )

    # The i'th entries of all groups form a tuple of columns that move together.
    linked_columns = [list(cols) for cols in zip(*column_groups)]

    # Build one full copy of ``df`` per permutation. Working column-wise (instead
    # of via ``iterrows``) keeps every column's dtype intact.
    frames: list[pd.DataFrame] = []
    for perm in permutations(range(len(column_groups))):
        permuted = df.copy()
        for cols in linked_columns:
            for target, source_idx in zip(cols, perm):
                # Read from the untouched ``df`` so earlier assignments in this
                # permutation cannot leak into later ones. The index of
                # ``df[...]`` equals that of ``permuted``, so no realignment occurs.
                permuted[target] = df[cols[source_idx]]
        frames.append(permuted)

    # Frame ``p`` occupies positions ``p * n_rows`` ... ``p * n_rows + n_rows - 1``
    # of the concatenation. Interleave positionally so that all permutations of
    # one original row are adjacent (identity permutation first).
    augmented = pd.concat(frames)
    n_rows, n_perms = len(df), len(frames)
    order = [p * n_rows + r for r in range(n_rows) for p in range(n_perms)]
    return augmented.iloc[order]
def df_apply_dependency_augmentation(
    df: pd.DataFrame,
    causing: tuple[str, Sequence],
    affected: Collection[tuple[str, Sequence]],
) -> pd.DataFrame:
    """Augment a dataframe if dependency invariant columns are present.

    This works with the concept of column-values pairs for causing and affected
    column. Any row present where the specified causing column has one of the
    provided values will trigger an augmentation on the affected columns. The latter
    are augmented by going through all their invariant values and adding respective
    new rows. If a triggering row's own affected values are not among these
    combinations, the row itself is kept as well (placed before its augmented
    copies), so that the original data is never lost.

    Args:
        df: The dataframe that should be augmented.
        causing: Causing column name and its causing values.
        affected: Affected columns and their invariant values.

    Returns:
        The augmented dataframe containing the original one. Augmented row indices
        are identical with the index of their original row, and rows keep their
        original relative order. Columns that are not affected keep their dtype.

    Raises:
        KeyError: If the causing column or an affected column is not present in
            ``df``.

    Examples:
        >>> df = pd.DataFrame({'A':[0,1],'B':[2,3], 'C': [5, 5], 'D': [6, 7]})
        >>> df
           A  B  C  D
        0  0  2  5  6
        1  1  3  5  7

        >>> causing = ('A', [0])
        >>> affected = [('B', [2,3,4])]
        >>> dfa = df_apply_dependency_augmentation(df, causing, affected)
        >>> dfa
           A  B  C  D
        0  0  2  5  6
        0  0  3  5  6
        0  0  4  5  6
        1  1  3  5  7

        >>> causing = ('A', [0])
        >>> affected = [('B', [4])]
        >>> dfa = df_apply_dependency_augmentation(df, causing, affected)
        >>> dfa
           A  B  C  D
        0  0  2  5  6
        0  0  4  5  6
        1  1  3  5  7

        >>> causing = ('A', [0, 1])
        >>> affected = [('B', [2,3])]
        >>> dfa = df_apply_dependency_augmentation(df, causing, affected)
        >>> dfa
           A  B  C  D
        0  0  2  5  6
        0  0  3  5  6
        1  1  2  5  7
        1  1  3  5  7

        >>> causing = ('A', [0])
        >>> affected = [('B', [2,3]), ('C', [5, 6])]
        >>> dfa = df_apply_dependency_augmentation(df, causing, affected)
        >>> dfa
           A  B  C  D
        0  0  2  5  6
        0  0  2  6  6
        0  0  3  5  6
        0  0  3  6  6
        1  1  3  5  7
    """
    col_causing, vals_causing = causing
    affected_pairs = list(affected)
    affected_cols = [col for col, _ in affected_pairs]
    # All combinations of invariant values, one tuple per output row. ``product()``
    # of nothing yields a single empty tuple, so an empty ``affected`` simply
    # reproduces the input instead of failing on unpacking.
    affected_inv_vals_combinations = list(
        product(*(vals for _, vals in affected_pairs))
    )

    causing_values = df[col_causing].tolist()
    original_affected = [df[col].tolist() for col in affected_cols]

    # For every output row: the position of its source row in ``df`` and the
    # values its affected columns take.
    positions: list[int] = []
    new_affected: list[list] = [[] for _ in affected_cols]
    any_triggered = False
    for pos, causing_value in enumerate(causing_values):
        current = tuple(column[pos] for column in original_affected)
        if causing_value not in vals_causing:
            # Unaffected rows are kept as they are
            row_variants = [current]
        else:
            any_triggered = True
            # Keep the original row if augmentation would not reproduce it
            row_variants = (
                affected_inv_vals_combinations
                if current in affected_inv_vals_combinations
                else [current, *affected_inv_vals_combinations]
            )
        positions.extend([pos] * len(row_variants))
        for values in row_variants:
            for column_values, value in zip(new_affected, values):
                column_values.append(value)

    if not any_triggered:
        # Nothing to augment (this also covers an empty ``df``)
        return df.copy()

    # Positional selection repeats rows while keeping their index labels and
    # the dtypes of all untouched columns; only the affected columns are rebuilt.
    augmented = df.iloc[positions].copy()
    for col, column_values in zip(affected_cols, new_affected):
        augmented[col] = column_values
    return augmented