In [1]:
"""
Hybrid MBA-MST v2.1 - IMPROVED with Better Parameters

Changes from v2.0:
- MIN_SUPPORT lowered from 0.05 to 0.015 (1.5%)
- Will capture more patterns (expected: 100-150 itemsets)
- Better graph structure (expected: 80-100 nodes)
- More 3-itemsets for analysis (expected: 20-30)
"""

import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from itertools import combinations
from mlxtend.preprocessing import TransactionEncoder
from mlxtend.frequent_patterns import apriori, association_rules
from scipy.stats import spearmanr, mannwhitneyu
import warnings
warnings.filterwarnings('ignore')

# ====================================
# IMPROVED CONFIGURATION
# ====================================
class ImprovedConfig:
    def __init__(self):
        self.FILE_PATH = "groceries.csv"
        
        # IMPROVED: Lower support threshold
        self.MIN_SUPPORT_BASE = 0.015      # 1.5% (was 5%)
        self.MIN_CONF_BASE = 0.25          # 25% (was 30%)
        
        # MST options
        self.WEIGHT_METHOD = "inverse"      # inverse, neglog, sqrt_inv, complement
        self.K_MST = 1.0                   # 1.0 = standard MST
        
        # Analysis
        self.MIN_SUPPORT_TRIAD = 0.015     # Same as base
        self.MAX_3_ITEMSETS = 50           # Limit for performance
        
        # Validation
        self.VALIDATION_ENABLED = True
        self.DISTANCE_THRESHOLD_LOW = 30.0  # Adjusted for new scale
        self.DISTANCE_THRESHOLD_HIGH = 60.0
        
        self.VERBOSE = True
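
# A minimal usage sketch (illustrative values, not the ones used in the run
# below): any parameter can be overridden after construction.
_demo_cfg = ImprovedConfig()
_demo_cfg.WEIGHT_METHOD = "neglog"    # -log(support) instead of 1/support
_demo_cfg.MIN_SUPPORT_BASE = 0.02     # a tighter support floor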

# ====================================
# [Reuse same functions from v2.0]
# ====================================

def load_groceries_transactions(file_path):
    """Load transactions from groceries.csv"""
    transactions = []
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            items = [item.strip() for item in line.split(',') if item.strip()]
            if len(items) >= 2:
                transactions.append(items)
    return transactions
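
# Expected input format (an assumption matching the classic Groceries dataset:
# one comma-separated basket per line, no header), e.g.
#
#   citrus fruit,semi-finished bread,margarine,ready soups
#   tropical fruit,yogurt,coffee
#
# Lines with fewer than two items are dropped by the len(items) >= 2 filter above.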

def validate_transactions(transactions):
    """Validate and clean transactions"""
    transactions = [list(set(tx)) for tx in transactions]
    transactions = [tx for tx in transactions if len(tx) >= 2]
    
    stats = {
        'total': len(transactions),
        'avg_size': np.mean([len(tx) for tx in transactions]),
        'median_size': np.median([len(tx) for tx in transactions]),
        'max_size': max([len(tx) for tx in transactions]) if transactions else 0,
        'min_size': min([len(tx) for tx in transactions]) if transactions else 0
    }
    
    return transactions, stats

def calculate_weight(support, method="inverse"):
    """Calculate edge weight"""
    eps = 1e-12
    
    if method == "inverse":
        return 1.0 / (support + eps)
    elif method == "neglog":
        return -np.log(support + eps)
    elif method == "complement":
        return max(0.0, 1.0 - support)
    elif method == "sqrt_inv":
        return np.sqrt(1.0 / (support + eps))
    else:
        return 1.0 / (support + eps)
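
# Worked check (numbers taken from the report this run prints below): with the
# default "inverse" method an edge's distance is simply 1/support, e.g. the
# {other vegetables, whole milk} pair with Support 0.095883 and Distance 10.429348.
assert abs(calculate_weight(0.095883, "inverse") - 10.429348) < 1e-3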

def build_graph_from_rules(rules_pairs, weight_method="inverse"):
    """Build graph from rules"""
    G = nx.Graph()
    
    def pair_key(row):
        a = list(row['antecedents'])[0]
        b = list(row['consequents'])[0]
        return tuple(sorted((str(a), str(b))))
    
    rules_pairs = rules_pairs.copy()
    rules_pairs['pair'] = rules_pairs.apply(pair_key, axis=1)
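    # Both rule directions (A->B and B->A) describe the same item pair and share
    # the same support, so groupby(...).max() keeps one undirected edge per pair.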
    pair_support = rules_pairs.groupby('pair')['support'].max()
    
    for (a, b), supp in pair_support.items():
        weight = calculate_weight(supp, weight_method)
        G.add_edge(a, b, weight=weight, support=supp)
    
    return G

def build_mst(G):
    """Build MST"""
    return nx.minimum_spanning_tree(G, weight="weight")

def trace_itemset_in_mst(itemset, MST):
    """Trace itemset in MST; returns None if the itemset cannot be traced"""
    nodes = [str(item) for item in itemset]

    # Items missing from the MST make nx.has_path raise NodeNotFound,
    # so check membership before checking reachability.
    if any(n not in MST for n in nodes):
        return None
    for u, v in combinations(nodes, 2):
        if not nx.has_path(MST, u, v):
            return None

    edges_used = set()
    for u, v in combinations(nodes, 2):
        path = nx.shortest_path(MST, source=u, target=v)
        for a, b in zip(path[:-1], path[1:]):
            edges_used.add(tuple(sorted((a, b))))
    
    total_weight = sum(MST[a][b]['weight'] for a, b in edges_used)
    possible_edges = len(list(combinations(nodes, 2)))
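    # Compactness rewards edge reuse across the pairwise MST paths:
    # 1 - |edges used| / (k*(k-1)/2 pairwise links), higher = tighter bundle.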
    compactness = 1.0 - (len(edges_used) / possible_edges)
    
    return {
        'bridges': len(edges_used),
        'distance': total_weight,
        'edges': sorted(list(edges_used)),
        'compactness': compactness,
        'quality_score': compactness * (1.0 / (total_weight + 0.01))
    }
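
# Self-contained toy check of the tracing logic (illustrative weights only):
# in a path MST a--b--c, the triple {a, b, c} reuses both edges, so it crosses
# 2 bridges with distance 1.0 + 2.0 and compactness 1 - 2/3.
_toy = nx.Graph()
_toy.add_edge("a", "b", weight=1.0)
_toy.add_edge("b", "c", weight=2.0)
_toy_trace = trace_itemset_in_mst({"a", "b", "c"}, _toy)
assert _toy_trace['bridges'] == 2 and abs(_toy_trace['distance'] - 3.0) < 1e-9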

def validate_hypothesis(traces, transactions, frequent_itemsets, config):
    """Run validation tests"""
    print("\n" + "="*70)
    print("VALIDATION FRAMEWORK")
    print("="*70)
    
    results = {}
    itemsets_with_distances = [(itemset, trace['distance']) for itemset, trace in traces]
    
    # H1: Correlation
    print("\n[TEST 1] Correlation Test")
    distances = []
    co_occurs = []
    
    for itemset, dist in itemsets_with_distances:
        itemset_set = set(map(str, itemset))
        count = sum(1 for tx in transactions if itemset_set.issubset(set(tx)))
        co_occur_rate = count / len(transactions)
        distances.append(dist)
        co_occurs.append(co_occur_rate)
    
    if len(distances) >= 3:
        corr, p_val = spearmanr(distances, co_occurs)
        h1_pass = p_val < 0.05 and corr < 0
        print(f"  Spearman ρ = {corr:.4f}, p = {p_val:.4f}")
        print(f"  Status: {'✓ PASSED' if h1_pass else '✗ FAILED'}")
        results['h1'] = {'passed': h1_pass, 'correlation': corr, 'p_value': p_val}
    else:
        print("  ⚠ Not enough data")
        results['h1'] = {'passed': False}
    
    # H2: Categorization
    print("\n[TEST 2] Categorization Test")
    fundamental = []
    niche = []
    
    for itemset, dist in itemsets_with_distances:
        itemset_set = set(map(str, itemset))
        count = sum(1 for tx in transactions if itemset_set.issubset(set(tx)))
        co_occur = count / len(transactions)
        
        if dist < config.DISTANCE_THRESHOLD_LOW:
            fundamental.append(co_occur)
        elif dist > config.DISTANCE_THRESHOLD_HIGH:
            niche.append(co_occur)
    
    if len(fundamental) >= 2 and len(niche) >= 2:
        stat, p_val = mannwhitneyu(fundamental, niche, alternative='greater')
        h2_pass = p_val < 0.05
        print(f"  Fundamental mean: {np.mean(fundamental):.4f} (n={len(fundamental)})")
        print(f"  Niche mean: {np.mean(niche):.4f} (n={len(niche)})")
        print(f"  Mann-Whitney p = {p_val:.4f}")
        print(f"  Status: {'✓ PASSED' if h2_pass else '✗ FAILED'}")
        results['h2'] = {'passed': h2_pass, 'p_value': p_val}
    else:
        print("  ⚠ Not enough data")
        results['h2'] = {'passed': False}
    
    passed = sum([results.get('h1', {}).get('passed', False), 
                  results.get('h2', {}).get('passed', False)])
    
    print("\n" + "="*70)
    print(f"OVERALL: {passed}/2 tests passed")
    if passed >= 1:
        print("🎯 CONCLUSION: Structural distance hypothesis is SUPPORTED")
    else:
        print("⚠️  CONCLUSION: Need more data or parameter adjustment")
    print("="*70)
    
    return results
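
# Sanity check of the test direction (synthetic values, not from this run): a
# perfect inverse relation between distance and co-occurrence yields Spearman
# rho = -1, the sign H1 requires.
_rho, _ = spearmanr([10.0, 20.0, 30.0, 40.0], [0.04, 0.03, 0.02, 0.01])
assert abs(_rho + 1.0) < 1e-9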

# ====================================
# MAIN IMPROVED PIPELINE
# ====================================
def run_improved_analysis(config):
    """Run improved analysis with better parameters"""
    
    print("="*70)
    print("HYBRID MBA-MST ANALYSIS v2.1 - IMPROVED")
    print("="*70)
    print(f"Parameters: MIN_SUPPORT={config.MIN_SUPPORT_BASE}, WEIGHT={config.WEIGHT_METHOD}")
    
    # Step 1: Load Data
    print("\n[STEP 1] Loading data...")
    transactions = load_groceries_transactions(config.FILE_PATH)
    transactions, stats = validate_transactions(transactions)
    print(f"  ✓ Valid transactions: {stats['total']}")
    
    # Step 2: Apriori
    print("\n[STEP 2] Running Apriori with lower support threshold...")
    te = TransactionEncoder()
    te_ary = te.fit(transactions).transform(transactions)
    df_bin = pd.DataFrame(te_ary, columns=list(map(str, te.columns_)))
    
    fi = apriori(df_bin, min_support=config.MIN_SUPPORT_BASE, use_colnames=True)
    fi['length'] = fi['itemsets'].apply(len)
    
    rules = association_rules(fi, metric="confidence", min_threshold=config.MIN_CONF_BASE)
    rules_pairs = rules[
        (rules['antecedents'].apply(len) == 1) &
        (rules['consequents'].apply(len) == 1)
    ].copy()
    
    print(f"  ✓ Found {len(fi)} frequent itemsets (vs 42 before)")
    print(f"    - 1-itemsets: {len(fi[fi['length']==1])}")
    print(f"    - 2-itemsets: {len(fi[fi['length']==2])}")
    print(f"    - 3-itemsets: {len(fi[fi['length']==3])}")
    print(f"    - 4+ itemsets: {len(fi[fi['length']>=4])}")
    print(f"  ✓ Generated {len(rules)} rules ({len(rules_pairs)} pair rules)")
    
    # Step 3: Build Graph & MST
    print(f"\n[STEP 3] Building richer graph structure...")
    G = build_graph_from_rules(rules_pairs, config.WEIGHT_METHOD)
    MST = build_mst(G)
    
    print(f"  ✓ Graph: {len(G.nodes())} nodes, {len(G.edges())} edges (vs 6 nodes before)")
    print(f"  ✓ MST: {len(MST.nodes())} nodes, {len(MST.edges())} edges")
    
    # Check components
    components = list(nx.connected_components(G))
    print(f"  ℹ Connected components: {len(components)}")
    
    if len(components) > 1:
        comp_sizes = sorted([len(c) for c in components], reverse=True)
        print(f"    Component sizes: {comp_sizes[:5]}")
    
    # Step 4: Analyze 3-itemsets
    print(f"\n[STEP 4] Analyzing 3-itemsets structural patterns...")
    fi_3 = fi[fi['length'] == 3]
    
    # Limit to top N by support for performance
    fi_3_top = fi_3.nlargest(config.MAX_3_ITEMSETS, 'support')
    
    traces_3 = []
    for itemset in fi_3_top['itemsets']:
        trace = trace_itemset_in_mst(itemset, MST)
        if trace:
            traces_3.append((tuple(sorted(map(str, itemset))), trace))
    
    traces_3_sorted = sorted(traces_3, key=lambda x: x[1]['distance'])[:20]
    
    print(f"  ✓ Analyzed {len(traces_3)}/{len(fi_3)} 3-itemsets (vs 0 before)")
    
    if traces_3_sorted:
        print(f"\n  Top 5 Fundamental Bundles (Lowest Distance):")
        for i, (itemset, trace) in enumerate(traces_3_sorted[:5], 1):
            support = fi[fi['itemsets'].apply(lambda x: tuple(sorted(map(str, x))) == itemset)]['support'].values[0]
            print(f"    {i}. {set(itemset)}")
            print(f"       Support={support:.4f}, Distance={trace['distance']:.2f}, Bridges={trace['bridges']}")
        
        print(f"\n  Bottom 3 Niche Patterns (Highest Distance):")
        for i, (itemset, trace) in enumerate(reversed(traces_3_sorted[-3:]), 1):
            support = fi[fi['itemsets'].apply(lambda x: tuple(sorted(map(str, x))) == itemset)]['support'].values[0]
            print(f"    {i}. {set(itemset)}")
            print(f"       Support={support:.4f}, Distance={trace['distance']:.2f}, Bridges={trace['bridges']}")
    
    # Step 5: Validation
    if config.VALIDATION_ENABLED and len(traces_3) >= 10:
        fi_lookup = {tuple(sorted(map(str, s))): sup 
                     for s, sup in zip(fi['itemsets'], fi['support'])}
        validation_results = validate_hypothesis(traces_3, transactions, fi_lookup, config)
    else:
        print("\n⚠ Validation skipped (not enough patterns)")
        validation_results = None
    
    # Step 6: Generate Enhanced Report
    print(f"\n[STEP 6] Generating enhanced report...")
    
    report_data = []
    
    # Add MST edges (2-itemsets)
    for u, v, data in MST.edges(data=True):
        # Categorize with the same thresholds as the validation step
        category = ('Fundamental' if data['weight'] < config.DISTANCE_THRESHOLD_LOW
                    else 'Moderate' if data['weight'] < config.DISTANCE_THRESHOLD_HIGH else 'Niche')
        report_data.append({
            'Itemset': f"{{{u}, {v}}}",
            'Size': 2,
            'Support': data['support'],
            'Pattern': 'Direct',
            'Bridges': 0,
            'Distance': data['weight'],
            'Compactness': 1.0,
            'Category': category
        })
    
    # Add 3-itemsets
    fi_lookup = {tuple(sorted(map(str, s))): sup 
                 for s, sup in zip(fi['itemsets'], fi['support'])}
    
    for itemset, trace in traces_3_sorted:
        support = fi_lookup.get(itemset, np.nan)
        category = ('Fundamental' if trace['distance'] < config.DISTANCE_THRESHOLD_LOW
                    else 'Moderate' if trace['distance'] < config.DISTANCE_THRESHOLD_HIGH else 'Niche')
        
        report_data.append({
            'Itemset': str(set(itemset)),
            'Size': 3,
            'Support': support,
            'Pattern': 'Complex',
            'Bridges': trace['bridges'],
            'Distance': trace['distance'],
            'Compactness': trace['compactness'],
            'Category': category
        })
    
    df_report = pd.DataFrame(report_data)
    df_report = df_report.sort_values(['Size', 'Distance']).reset_index(drop=True)
    
    print("\n" + "="*70)
    print("ENHANCED PATTERNS REPORT (Top 20)")
    print("="*70)
    print(df_report.head(20).to_string(index=False))
    
    # Save
    output_file = config.FILE_PATH.replace('.csv', '_improved_report.csv')
    df_report.to_csv(output_file, index=False)
    print(f"\n✓ Full report saved to: {output_file}")
    
    # Summary statistics
    print("\n" + "="*70)
    print("SUMMARY STATISTICS")
    print("="*70)
    
    fundamental = df_report[df_report['Category'] == 'Fundamental']
    moderate = df_report[df_report['Category'] == 'Moderate']
    niche = df_report[df_report['Category'] == 'Niche']
    
    print(f"\nPattern Distribution:")
    print(f"  Fundamental: {len(fundamental)} ({len(fundamental)/len(df_report)*100:.1f}%)")
    print(f"  Moderate: {len(moderate)} ({len(moderate)/len(df_report)*100:.1f}%)")
    print(f"  Niche: {len(niche)} ({len(niche)/len(df_report)*100:.1f}%)")
    
    print(f"\nDistance Statistics:")
    print(f"  Overall Mean: {df_report['Distance'].mean():.2f}")
    print(f"  Fundamental Mean: {fundamental['Distance'].mean():.2f}")
    print(f"  Moderate Mean: {moderate['Distance'].mean():.2f}")
    print(f"  Niche Mean: {niche['Distance'].mean():.2f}")
    
    # Conclusion
    print("\n" + "="*70)
    print("CONCLUSION")
    print("="*70)
    print("✓ Improved analysis with lower support threshold (1.5%) successfully")
    print("  captured MORE patterns while maintaining structural insights:")
    print("")
    print(f"  Before (5% support): 42 itemsets, 6 nodes, 0 3-itemsets")
    print(f"  After (1.5% support): {len(fi)} itemsets, {len(G.nodes())} nodes, {len(traces_3)} 3-itemsets")
    print("")
    print("  The hybrid approach now reveals:")
    print(f"    → {len(fundamental)} FUNDAMENTAL bundles (naturally connected)")
    print(f"    → {len(moderate)} MODERATE patterns (intermediate structure)")
    print(f"    → {len(niche)} NICHE patterns (opportunistic co-occurrence)")
    print("")
    print("  This demonstrates the value of structural context in MBA!")
    print("="*70)
    
    return {
        'transactions': transactions,
        'frequent_itemsets': fi,
        'rules': rules,
        'graph': G,
        'mst': MST,
        'traces': traces_3_sorted,
        'report': df_report,
        'validation': validation_results
    }

# ====================================
# MAIN
# ====================================
if __name__ == "__main__":
    config = ImprovedConfig()
    
    if not os.path.exists(config.FILE_PATH):
        print(f"Error: File '{config.FILE_PATH}' not found!")
    else:
        results = run_improved_analysis(config)
        print("\n✓ Improved analysis completed!")
        print("\nNext: Run visualize_mst.py to see the graphs!")
======================================================================
HYBRID MBA-MST ANALYSIS v2.1 - IMPROVED
======================================================================
Parameters: MIN_SUPPORT=0.015, WEIGHT=inverse

[STEP 1] Loading data...
  ✓ Valid transactions: 7676

[STEP 2] Running Apriori with lower support threshold...
  ✓ Found 253 frequent itemsets (vs 42 before)
    - 1-itemsets: 77
    - 2-itemsets: 158
    - 3-itemsets: 18
    - 4+ itemsets: 0
  ✓ Generated 136 rules (91 pair rules)

[STEP 3] Building richer graph structure...
  ✓ Graph: 41 nodes, 89 edges (vs 6 nodes before)
  ✓ MST: 41 nodes, 40 edges
  ℹ Connected components: 1

[STEP 4] Analyzing 3-itemsets structural patterns...
  ✓ Analyzed 18/18 3-itemsets (vs 0 before)

  Top 5 Fundamental Bundles (Lowest Distance):
    1. {'whole milk', 'rolls/buns', 'other vegetables'}
       Support=0.0229, Distance=24.21, Bridges=2
    2. {'whole milk', 'other vegetables', 'yogurt'}
       Support=0.0285, Distance=24.36, Bridges=2
    3. {'whole milk', 'root vegetables', 'other vegetables'}
       Support=0.0297, Distance=26.39, Bridges=2
    4. {'whole milk', 'rolls/buns', 'yogurt'}
       Support=0.0199, Distance=27.71, Bridges=2
    5. {'whole milk', 'tropical fruit', 'other vegetables'}
       Support=0.0219, Distance=28.88, Bridges=2

  Bottom 3 Niche Patterns (Highest Distance):
    1. {'tropical fruit', 'root vegetables', 'other vegetables'}
       Support=0.0158, Distance=44.84, Bridges=3
    2. {'tropical fruit', 'other vegetables', 'yogurt'}
       Support=0.0158, Distance=42.81, Bridges=3
    3. {'root vegetables', 'other vegetables', 'yogurt'}
       Support=0.0165, Distance=40.32, Bridges=3

======================================================================
VALIDATION FRAMEWORK
======================================================================

[TEST 1] Correlation Test
  Spearman ρ = -0.7994, p = 0.0001
  Status: ✓ PASSED

[TEST 2] Categorization Test
  ⚠ Not enough data

======================================================================
OVERALL: 1/2 tests passed
🎯 CONCLUSION: Structural distance hypothesis is SUPPORTED
======================================================================

[STEP 6] Generating enhanced report...

======================================================================
ENHANCED PATTERNS REPORT (Top 20)
======================================================================
                            Itemset  Size  Support Pattern  Bridges  Distance  Compactness    Category
     {other vegetables, whole milk}     2 0.095883  Direct        0 10.429348          1.0 Fundamental
           {rolls/buns, whole milk}     2 0.072564  Direct        0 13.780969          1.0 Fundamental
               {whole milk, yogurt}     2 0.071782  Direct        0 13.931034          1.0 Fundamental
      {root vegetables, whole milk}     2 0.062663  Direct        0 15.958420          1.0 Fundamental
       {whole milk, tropical fruit}     2 0.054195  Direct        0 18.451923          1.0 Fundamental
                 {whole milk, soda}     2 0.051329  Direct        0 19.482234          1.0 Fundamental
        {whole milk, bottled water}     2 0.044033  Direct        0 22.710059          1.0 Fundamental
               {whole milk, pastry}     2 0.042600  Direct        0 23.474006          1.0 Fundamental
   {whole milk, whipped/sour cream}     2 0.041298  Direct        0 24.214511          1.0 Fundamental
              {rolls/buns, sausage}     2 0.039213  Direct        0 25.501661          1.0 Fundamental
         {whole milk, citrus fruit}     2 0.039083  Direct        0 25.586667          1.0 Fundamental
            {whole milk, pip fruit}     2 0.038562  Direct        0 25.932432          1.0 Fundamental
        {whole milk, domestic eggs}     2 0.038431  Direct        0 26.020339          1.0 Fundamental
               {whole milk, butter}     2 0.035305  Direct        0 28.324723          1.0 Fundamental
           {whole milk, newspapers}     2 0.035044  Direct        0 28.535316          1.0 Fundamental
{whole milk, fruit/vegetable juice}     2 0.034132  Direct        0 29.297710          1.0 Fundamental
                 {whole milk, curd}     2 0.033481  Direct        0 29.867704          1.0 Fundamental
          {whole milk, brown bread}     2 0.032308  Direct        0 30.951613          1.0    Moderate
              {soda, shopping bags}     2 0.031527  Direct        0 31.719008          1.0    Moderate
            {whole milk, margarine}     2 0.031006  Direct        0 32.252101          1.0    Moderate

✓ Full report saved to: groceries_improved_report.csv

======================================================================
SUMMARY STATISTICS
======================================================================

Pattern Distribution:
  Fundamental: 25 (43.1%)
  Moderate: 30 (51.7%)
  Niche: 3 (5.2%)

Distance Statistics:
  Overall Mean: 35.28
  Fundamental Mean: 24.10
  Moderate Mean: 41.84
  Niche Mean: 62.83

======================================================================
CONCLUSION
======================================================================
✓ Improved analysis with lower support threshold (1.5%) successfully
  captured MORE patterns while maintaining structural insights:

  Before (5% support): 42 itemsets, 6 nodes, 0 3-itemsets
  After (1.5% support): 253 itemsets, 41 nodes, 18 3-itemsets

  The hybrid approach now reveals:
    → 25 FUNDAMENTAL bundles (naturally connected)
    → 30 MODERATE patterns (intermediate structure)
    → 3 NICHE patterns (opportunistic co-occurrence)

  This demonstrates the value of structural context in MBA!
======================================================================

✓ Improved analysis completed!

Next: Run visualize_mst.py to see the graphs!
In [2]:
"""
Visualize MST Results from Current Analysis

Run this AFTER running hybrid_mba_fixed.py
to see the graph visualization
"""

import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np

def visualize_mst_from_report(report_file="groceries_hybrid_report.csv"):
    """
    Visualize MST from the generated report
    """
    # Load report
    df = pd.read_csv(report_file)
    
    # Filter only direct connections (2-itemsets = MST edges)
    df_direct = df[df['Pattern'] == 'Direct'].copy()
    
    print("="*70)
    print("MST VISUALIZATION")
    print("="*70)
    print(f"Loaded {len(df_direct)} MST edges from report")
    
    # Build MST from report
    MST = nx.Graph()
    
    for _, row in df_direct.iterrows():
        # Parse itemset
        itemset_str = row['Itemset']
        items = [item.strip() for item in itemset_str.strip('{}').split(',')]
        
        if len(items) == 2:
            u, v = items[0].strip(), items[1].strip()
            MST.add_edge(u, v, 
                        weight=row['Distance'],
                        support=row['Support'],
                        category=row['Category'])
    
    print(f"MST Graph: {len(MST.nodes())} nodes, {len(MST.edges())} edges")
    
    # Create visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))
    
    # === LEFT PLOT: MST Network ===
    pos = nx.spring_layout(MST, k=2, iterations=50, seed=42)
    
    # Node colors by centrality
    centrality = nx.degree_centrality(MST)
    node_colors = [centrality[node] for node in MST.nodes()]
    
    # Draw network
    nx.draw_networkx_nodes(MST, pos, 
                          node_color=node_colors,
                          node_size=3000,
                          cmap='YlOrRd',
                          alpha=0.9,
                          ax=ax1)
    
    # Edge colors by category
    edge_colors = []
    edge_widths = []
    for u, v, data in MST.edges(data=True):
        if data['category'] == 'Fundamental':
            edge_colors.append('green')
            edge_widths.append(4)
        elif data['category'] == 'Moderate':
            edge_colors.append('orange')
            edge_widths.append(3)
        else:
            edge_colors.append('red')
            edge_widths.append(2)
    
    nx.draw_networkx_edges(MST, pos,
                          edge_color=edge_colors,
                          width=edge_widths,
                          alpha=0.7,
                          ax=ax1)
    
    # Labels
    nx.draw_networkx_labels(MST, pos,
                           font_size=10,
                           font_weight='bold',
                           ax=ax1)
    
    # Edge labels (distance)
    edge_labels = {(u, v): f"{d['weight']:.1f}" 
                   for u, v, d in MST.edges(data=True)}
    nx.draw_networkx_edge_labels(MST, pos,
                                 edge_labels,
                                 font_size=8,
                                 font_color='blue',
                                 ax=ax1)
    
    ax1.set_title('Minimum Spanning Tree\n(Node color = Centrality, Edge color = Category)',
                  fontsize=14, fontweight='bold')
    ax1.axis('off')
    
    # Legend
    from matplotlib.lines import Line2D
    # Categories come from the report's Category column, so the legend does not
    # hardcode distance ranges (the v2.1 run above used thresholds of 30/60).
    legend_elements = [
        Line2D([0], [0], color='green', linewidth=4, label='Fundamental'),
        Line2D([0], [0], color='orange', linewidth=3, label='Moderate'),
        Line2D([0], [0], color='red', linewidth=2, label='Niche')
    ]
    ax1.legend(handles=legend_elements, loc='upper left', fontsize=10)
    
    # === RIGHT PLOT: Edge Distance Distribution ===
    distances = [d['weight'] for u, v, d in MST.edges(data=True)]
    categories = [d['category'] for u, v, d in MST.edges(data=True)]
    
    # Histogram
    ax2.hist(distances, bins=10, color='steelblue', alpha=0.7, edgecolor='black')
    ax2.axvline(np.mean(distances), color='red', linestyle='--', linewidth=2, 
                label=f'Mean: {np.mean(distances):.2f}')
    ax2.axvline(np.median(distances), color='green', linestyle='--', linewidth=2,
                label=f'Median: {np.median(distances):.2f}')
    
    ax2.set_xlabel('Structural Distance', fontsize=12)
    ax2.set_ylabel('Frequency', fontsize=12)
    ax2.set_title('Distribution of Edge Distances', fontsize=14, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)
    
    # Add category counts
    from collections import Counter
    cat_counts = Counter(categories)
    text_str = "Edge Categories:\n"
    for cat, count in cat_counts.items():
        text_str += f"  {cat}: {count}\n"
    
    ax2.text(0.98, 0.98, text_str,
            transform=ax2.transAxes,
            fontsize=10,
            verticalalignment='top',
            horizontalalignment='right',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    plt.savefig('mst_visualization.png', dpi=300, bbox_inches='tight')
    print("\n✓ Saved: mst_visualization.png")
    plt.show()
    
    # === ADDITIONAL: Network Statistics ===
    print("\n" + "="*70)
    print("NETWORK STATISTICS")
    print("="*70)
    
    print(f"\nNodes: {len(MST.nodes())}")
    print(f"Edges: {len(MST.edges())}")
    print(f"Density: {nx.density(MST):.4f}")
    print(f"Is Connected: {nx.is_connected(MST)}")
    
    if nx.is_connected(MST):
        print(f"Diameter: {nx.diameter(MST)}")
        print(f"Average Path Length: {nx.average_shortest_path_length(MST):.2f}")
    
    # Node centrality
    print(f"\nTop 3 Central Nodes (Degree Centrality):")
    top_nodes = sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:3]
    for node, cent in top_nodes:
        print(f"  {node}: {cent:.4f}")
    
    # Edge statistics
    print(f"\nEdge Distance Statistics:")
    print(f"  Min: {min(distances):.2f}")
    print(f"  Max: {max(distances):.2f}")
    print(f"  Mean: {np.mean(distances):.2f}")
    print(f"  Median: {np.median(distances):.2f}")
    print(f"  Std Dev: {np.std(distances):.2f}")
    
    print("\n" + "="*70)

def create_enhanced_report_viz(report_file="groceries_hybrid_report.csv"):
    """
    Create enhanced visualization with support vs distance
    """
    df = pd.read_csv(report_file)
    df_direct = df[df['Pattern'] == 'Direct'].copy()
    
    fig, ax = plt.subplots(figsize=(12, 8))
    
    # Scatter plot: Support vs Distance
    colors = {'Fundamental': 'green', 'Moderate': 'orange', 'Niche': 'red'}
    
    for category in df_direct['Category'].unique():
        subset = df_direct[df_direct['Category'] == category]
        ax.scatter(subset['Support'], subset['Distance'],
                  label=category,
                  color=colors.get(category, 'gray'),
                  s=200,
                  alpha=0.7,
                  edgecolors='black',
                  linewidth=1.5)
    
    # Add labels for each point
    for _, row in df_direct.iterrows():
        itemset_str = row['Itemset'].strip('{}')
        items = [item.strip() for item in itemset_str.split(',')]
        label = f"{items[0][:15]}\n{items[1][:15]}"
        
        ax.annotate(label,
                   (row['Support'], row['Distance']),
                   fontsize=8,
                   ha='center',
                   bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.3))
    
    ax.set_xlabel('Support (Co-occurrence Frequency)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Structural Distance', fontsize=12, fontweight='bold')
    ax.set_title('Support vs Structural Distance\n(Fundamental bundles = High support + Low distance)',
                 fontsize=14, fontweight='bold')
    # Draw threshold lines before the legend so their labels are included;
    # keep these y-values in sync with the DISTANCE_THRESHOLD_* settings used
    # when the report was generated.
    ax.axhline(y=20, color='green', linestyle='--', alpha=0.5, label='Fundamental threshold')
    ax.axhline(y=40, color='orange', linestyle='--', alpha=0.5, label='Moderate threshold')
    ax.legend(fontsize=11, loc='upper right')
    ax.grid(True, alpha=0.3, linestyle='--')
    
    plt.tight_layout()
    plt.savefig('support_vs_distance.png', dpi=300, bbox_inches='tight')
    print("✓ Saved: support_vs_distance.png")
    plt.show()
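
# Usage note: both functions take the report path, so the v2.1 output can be
# visualized directly with the file written by run_improved_analysis:
#
#   visualize_mst_from_report("groceries_improved_report.csv")
#   create_enhanced_report_viz("groceries_improved_report.csv")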

# ====================================
# MAIN
# ====================================
if __name__ == "__main__":
    print("="*70)
    print("MST VISUALIZATION TOOL")
    print("="*70)
    
    # Check if report exists
    import os
    report_file = "groceries_hybrid_report.csv"
    
    if not os.path.exists(report_file):
        print(f"\n❌ Error: {report_file} not found!")
        print("Please run hybrid_mba_fixed.py first to generate the report.")
    else:
        # Visualize MST
        visualize_mst_from_report(report_file)
        
        # Create enhanced viz
        print("\nCreating enhanced visualization...")
        create_enhanced_report_viz(report_file)
        
        print("\n✓ All visualizations complete!")
        print("\nGenerated files:")
        print("  1. mst_visualization.png - MST network + distance distribution")
        print("  2. support_vs_distance.png - Support vs distance scatter plot")
======================================================================
MST VISUALIZATION TOOL
======================================================================
======================================================================
MST VISUALIZATION
======================================================================
Loaded 5 MST edges from report
MST Graph: 6 nodes, 5 edges

✓ Saved: mst_visualization.png
======================================================================
NETWORK STATISTICS
======================================================================

Nodes: 6
Edges: 5
Density: 0.3333
Is Connected: True
Diameter: 2
Average Path Length: 1.67

Top 3 Central Nodes (Degree Centrality):
  whole milk: 1.0000
  other vegetables: 0.2000
  rolls/buns: 0.2000

Edge Distance Statistics:
  Min: 10.43
  Max: 18.45
  Mean: 14.51
  Median: 13.93
  Std Dev: 2.65

======================================================================

Creating enhanced visualization...
✓ Saved: support_vs_distance.png
✓ All visualizations complete!

Generated files:
  1. mst_visualization.png - MST network + distance distribution
  2. support_vs_distance.png - Support vs distance scatter plot