In [ ]:
In [1]:
"""
Hybrid MBA-MST v2.1 - IMPROVED with Better Parameters
Changes from v2.0:
- MIN_SUPPORT lowered from 0.05 to 0.015 (1.5%)
- Will capture more patterns (expected: 100-150 itemsets)
- Better graph structure (expected: 80-100 nodes)
- More 3-itemsets for analysis (expected: 20-30)
"""
import os
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from itertools import combinations
from mlxtend.preprocessing import TransactionEncoder
from mlxtend.frequent_patterns import apriori, association_rules
from scipy.stats import spearmanr, mannwhitneyu
import warnings
warnings.filterwarnings('ignore')
# ====================================
# IMPROVED CONFIGURATION
# ====================================
class ImprovedConfig:
def __init__(self):
self.FILE_PATH = "groceries.csv"
# IMPROVED: Lower support threshold
self.MIN_SUPPORT_BASE = 0.015 # 1.5% (was 5%)
self.MIN_CONF_BASE = 0.25 # 25% (was 30%)
# MST options
self.WEIGHT_METHOD = "inverse" # inverse, neglog, sqrt_inv, complement
self.K_MST = 1.0 # 1.0 = standard MST
# Analysis
self.MIN_SUPPORT_TRIAD = 0.015 # Same as base
self.MAX_3_ITEMSETS = 50 # Limit for performance
# Validation
self.VALIDATION_ENABLED = True
self.DISTANCE_THRESHOLD_LOW = 30.0 # Adjusted for new scale
self.DISTANCE_THRESHOLD_HIGH = 60.0
self.VERBOSE = True
# ====================================
# [Reuse same functions from v2.0]
# ====================================
def load_groceries_transactions(file_path):
"""Load transactions from groceries.csv"""
transactions = []
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
for line in f:
line = line.strip()
if not line:
continue
items = [item.strip() for item in line.split(',') if item.strip()]
if len(items) >= 2:
transactions.append(items)
return transactions
def validate_transactions(transactions):
"""Validate and clean transactions"""
transactions = [list(set(tx)) for tx in transactions]
transactions = [tx for tx in transactions if len(tx) >= 2]
stats = {
'total': len(transactions),
'avg_size': np.mean([len(tx) for tx in transactions]),
'median_size': np.median([len(tx) for tx in transactions]),
'max_size': max([len(tx) for tx in transactions]) if transactions else 0,
'min_size': min([len(tx) for tx in transactions]) if transactions else 0
}
return transactions, stats
def calculate_weight(support, method="inverse"):
"""Calculate edge weight"""
eps = 1e-12
if method == "inverse":
return 1.0 / (support + eps)
elif method == "neglog":
return -np.log(support + eps)
elif method == "complement":
return max(0.0, 1.0 - support)
elif method == "sqrt_inv":
return np.sqrt(1.0 / (support + eps))
else:
return 1.0 / (support + eps)
def build_graph_from_rules(rules_pairs, weight_method="inverse"):
"""Build graph from rules"""
G = nx.Graph()
def pair_key(row):
a = list(row['antecedents'])[0]
b = list(row['consequents'])[0]
return tuple(sorted((str(a), str(b))))
rules_pairs = rules_pairs.copy()
rules_pairs['pair'] = rules_pairs.apply(pair_key, axis=1)
pair_support = rules_pairs.groupby('pair')['support'].max()
for (a, b), supp in pair_support.items():
weight = calculate_weight(supp, weight_method)
G.add_edge(a, b, weight=weight, support=supp)
return G
def build_mst(G):
"""Build MST"""
return nx.minimum_spanning_tree(G, weight="weight")
def trace_itemset_in_mst(itemset, MST):
"""Trace itemset in MST"""
nodes = [str(item) for item in itemset]
for u, v in combinations(nodes, 2):
if not nx.has_path(MST, u, v):
return None
edges_used = set()
for u, v in combinations(nodes, 2):
path = nx.shortest_path(MST, source=u, target=v)
for a, b in zip(path[:-1], path[1:]):
edges_used.add(tuple(sorted((a, b))))
total_weight = sum(MST[a][b]['weight'] for a, b in edges_used)
possible_edges = len(list(combinations(nodes, 2)))
compactness = 1.0 - (len(edges_used) / possible_edges)
return {
'bridges': len(edges_used),
'distance': total_weight,
'edges': sorted(list(edges_used)),
'compactness': compactness,
'quality_score': compactness * (1.0 / (total_weight + 0.01))
}
def validate_hypothesis(traces, transactions, frequent_itemsets, config):
"""Run validation tests"""
print("\n" + "="*70)
print("VALIDATION FRAMEWORK")
print("="*70)
results = {}
itemsets_with_distances = [(itemset, trace['distance']) for itemset, trace in traces]
# H1: Correlation
print("\n[TEST 1] Correlation Test")
distances = []
co_occurs = []
for itemset, dist in itemsets_with_distances:
itemset_set = set(map(str, itemset))
count = sum(1 for tx in transactions if itemset_set.issubset(set(tx)))
co_occur_rate = count / len(transactions)
distances.append(dist)
co_occurs.append(co_occur_rate)
if len(distances) >= 3:
corr, p_val = spearmanr(distances, co_occurs)
h1_pass = p_val < 0.05 and corr < 0
print(f" Spearman ρ = {corr:.4f}, p = {p_val:.4f}")
print(f" Status: {'✓ PASSED' if h1_pass else '✗ FAILED'}")
results['h1'] = {'passed': h1_pass, 'correlation': corr, 'p_value': p_val}
else:
print(" ⚠ Not enough data")
results['h1'] = {'passed': False}
# H2: Categorization
print("\n[TEST 2] Categorization Test")
fundamental = []
niche = []
for itemset, dist in itemsets_with_distances:
itemset_set = set(map(str, itemset))
count = sum(1 for tx in transactions if itemset_set.issubset(set(tx)))
co_occur = count / len(transactions)
if dist < config.DISTANCE_THRESHOLD_LOW:
fundamental.append(co_occur)
elif dist > config.DISTANCE_THRESHOLD_HIGH:
niche.append(co_occur)
if len(fundamental) >= 2 and len(niche) >= 2:
stat, p_val = mannwhitneyu(fundamental, niche, alternative='greater')
h2_pass = p_val < 0.05
print(f" Fundamental mean: {np.mean(fundamental):.4f} (n={len(fundamental)})")
print(f" Niche mean: {np.mean(niche):.4f} (n={len(niche)})")
print(f" Mann-Whitney p = {p_val:.4f}")
print(f" Status: {'✓ PASSED' if h2_pass else '✗ FAILED'}")
results['h2'] = {'passed': h2_pass, 'p_value': p_val}
else:
print(" ⚠ Not enough data")
results['h2'] = {'passed': False}
passed = sum([results.get('h1', {}).get('passed', False),
results.get('h2', {}).get('passed', False)])
print("\n" + "="*70)
print(f"OVERALL: {passed}/2 tests passed")
if passed >= 1:
print("🎯 CONCLUSION: Structural distance hypothesis is SUPPORTED")
else:
print("⚠️ CONCLUSION: Need more data or parameter adjustment")
print("="*70)
return results
# ====================================
# MAIN IMPROVED PIPELINE
# ====================================
def run_improved_analysis(config):
"""Run improved analysis with better parameters"""
print("="*70)
print("HYBRID MBA-MST ANALYSIS v2.1 - IMPROVED")
print("="*70)
print(f"Parameters: MIN_SUPPORT={config.MIN_SUPPORT_BASE}, WEIGHT={config.WEIGHT_METHOD}")
# Step 1: Load Data
print("\n[STEP 1] Loading data...")
transactions = load_groceries_transactions(config.FILE_PATH)
transactions, stats = validate_transactions(transactions)
print(f" ✓ Valid transactions: {stats['total']}")
# Step 2: Apriori
print("\n[STEP 2] Running Apriori with lower support threshold...")
te = TransactionEncoder()
te_ary = te.fit(transactions).transform(transactions)
df_bin = pd.DataFrame(te_ary, columns=list(map(str, te.columns_)))
fi = apriori(df_bin, min_support=config.MIN_SUPPORT_BASE, use_colnames=True)
fi['length'] = fi['itemsets'].apply(len)
rules = association_rules(fi, metric="confidence", min_threshold=config.MIN_CONF_BASE)
rules_pairs = rules[
(rules['antecedents'].apply(len) == 1) &
(rules['consequents'].apply(len) == 1)
].copy()
print(f" ✓ Found {len(fi)} frequent itemsets (vs 42 before)")
print(f" - 1-itemsets: {len(fi[fi['length']==1])}")
print(f" - 2-itemsets: {len(fi[fi['length']==2])}")
print(f" - 3-itemsets: {len(fi[fi['length']==3])}")
print(f" - 4+ itemsets: {len(fi[fi['length']>=4])}")
print(f" ✓ Generated {len(rules)} rules ({len(rules_pairs)} pair rules)")
# Step 3: Build Graph & MST
print(f"\n[STEP 3] Building richer graph structure...")
G = build_graph_from_rules(rules_pairs, config.WEIGHT_METHOD)
MST = build_mst(G)
print(f" ✓ Graph: {len(G.nodes())} nodes, {len(G.edges())} edges (vs 6 nodes before)")
print(f" ✓ MST: {len(MST.nodes())} nodes, {len(MST.edges())} edges")
# Check components
components = list(nx.connected_components(G))
print(f" ℹ Connected components: {len(components)}")
if len(components) > 1:
comp_sizes = sorted([len(c) for c in components], reverse=True)
print(f" Component sizes: {comp_sizes[:5]}")
# Step 4: Analyze 3-itemsets
print(f"\n[STEP 4] Analyzing 3-itemsets structural patterns...")
fi_3 = fi[fi['length'] == 3]
# Limit to top N by support for performance
fi_3_top = fi_3.nlargest(config.MAX_3_ITEMSETS, 'support')
traces_3 = []
for itemset in fi_3_top['itemsets']:
trace = trace_itemset_in_mst(itemset, MST)
if trace:
traces_3.append((tuple(sorted(map(str, itemset))), trace))
traces_3_sorted = sorted(traces_3, key=lambda x: x[1]['distance'])[:20]
print(f" ✓ Analyzed {len(traces_3)}/{len(fi_3)} 3-itemsets (vs 0 before)")
if traces_3_sorted:
print(f"\n Top 5 Fundamental Bundles (Lowest Distance):")
for i, (itemset, trace) in enumerate(traces_3_sorted[:5], 1):
support = fi[fi['itemsets'].apply(lambda x: tuple(sorted(map(str, x))) == itemset)]['support'].values[0]
print(f" {i}. {set(itemset)}")
print(f" Support={support:.4f}, Distance={trace['distance']:.2f}, Bridges={trace['bridges']}")
print(f"\n Bottom 3 Niche Patterns (Highest Distance):")
for i, (itemset, trace) in enumerate(reversed(traces_3_sorted[-3:]), 1):
support = fi[fi['itemsets'].apply(lambda x: tuple(sorted(map(str, x))) == itemset)]['support'].values[0]
print(f" {i}. {set(itemset)}")
print(f" Support={support:.4f}, Distance={trace['distance']:.2f}, Bridges={trace['bridges']}")
# Step 5: Validation
if config.VALIDATION_ENABLED and len(traces_3) >= 10:
fi_lookup = {tuple(sorted(map(str, s))): sup
for s, sup in zip(fi['itemsets'], fi['support'])}
validation_results = validate_hypothesis(traces_3, transactions, fi_lookup, config)
else:
print("\n⚠ Validation skipped (not enough patterns)")
validation_results = None
# Step 6: Generate Enhanced Report
print(f"\n[STEP 6] Generating enhanced report...")
report_data = []
# Add MST edges (2-itemsets)
for u, v, data in MST.edges(data=True):
category = 'Fundamental' if data['weight'] < 30 else ('Moderate' if data['weight'] < 60 else 'Niche')
report_data.append({
'Itemset': f"{{{u}, {v}}}",
'Size': 2,
'Support': data['support'],
'Pattern': 'Direct',
'Bridges': 0,
'Distance': data['weight'],
'Compactness': 1.0,
'Category': category
})
# Add 3-itemsets
fi_lookup = {tuple(sorted(map(str, s))): sup
for s, sup in zip(fi['itemsets'], fi['support'])}
for itemset, trace in traces_3_sorted:
support = fi_lookup.get(itemset, np.nan)
category = 'Fundamental' if trace['distance'] < 30 else ('Moderate' if trace['distance'] < 60 else 'Niche')
report_data.append({
'Itemset': str(set(itemset)),
'Size': 3,
'Support': support,
'Pattern': 'Complex',
'Bridges': trace['bridges'],
'Distance': trace['distance'],
'Compactness': trace['compactness'],
'Category': category
})
df_report = pd.DataFrame(report_data)
df_report = df_report.sort_values(['Size', 'Distance']).reset_index(drop=True)
print("\n" + "="*70)
print("ENHANCED PATTERNS REPORT (Top 20)")
print("="*70)
print(df_report.head(20).to_string(index=False))
# Save
output_file = config.FILE_PATH.replace('.csv', '_improved_report.csv')
df_report.to_csv(output_file, index=False)
print(f"\n✓ Full report saved to: {output_file}")
# Summary statistics
print("\n" + "="*70)
print("SUMMARY STATISTICS")
print("="*70)
fundamental = df_report[df_report['Category'] == 'Fundamental']
moderate = df_report[df_report['Category'] == 'Moderate']
niche = df_report[df_report['Category'] == 'Niche']
print(f"\nPattern Distribution:")
print(f" Fundamental: {len(fundamental)} ({len(fundamental)/len(df_report)*100:.1f}%)")
print(f" Moderate: {len(moderate)} ({len(moderate)/len(df_report)*100:.1f}%)")
print(f" Niche: {len(niche)} ({len(niche)/len(df_report)*100:.1f}%)")
print(f"\nDistance Statistics:")
print(f" Overall Mean: {df_report['Distance'].mean():.2f}")
print(f" Fundamental Mean: {fundamental['Distance'].mean():.2f}")
print(f" Moderate Mean: {moderate['Distance'].mean():.2f}")
print(f" Niche Mean: {niche['Distance'].mean():.2f}")
# Conclusion
print("\n" + "="*70)
print("CONCLUSION")
print("="*70)
print("✓ Improved analysis with lower support threshold (1.5%) successfully")
print(" captured MORE patterns while maintaining structural insights:")
print("")
print(f" Before (5% support): 42 itemsets, 6 nodes, 0 3-itemsets")
print(f" After (1.5% support): {len(fi)} itemsets, {len(G.nodes())} nodes, {len(traces_3)} 3-itemsets")
print("")
print(" The hybrid approach now reveals:")
print(f" → {len(fundamental)} FUNDAMENTAL bundles (naturally connected)")
print(f" → {len(moderate)} MODERATE patterns (intermediate structure)")
print(f" → {len(niche)} NICHE patterns (opportunistic co-occurrence)")
print("")
print(" This demonstrates the value of structural context in MBA!")
print("="*70)
return {
'transactions': transactions,
'frequent_itemsets': fi,
'rules': rules,
'graph': G,
'mst': MST,
'traces': traces_3_sorted,
'report': df_report,
'validation': validation_results
}
# ====================================
# MAIN
# ====================================
if __name__ == "__main__":
config = ImprovedConfig()
if not os.path.exists(config.FILE_PATH):
print(f"Error: File '{config.FILE_PATH}' not found!")
else:
results = run_improved_analysis(config)
print("\n✓ Improved analysis completed!")
print("\nNext: Run visualize_mst.py to see the graphs!")
======================================================================
HYBRID MBA-MST ANALYSIS v2.1 - IMPROVED
======================================================================
Parameters: MIN_SUPPORT=0.015, WEIGHT=inverse
[STEP 1] Loading data...
✓ Valid transactions: 7676
[STEP 2] Running Apriori with lower support threshold...
✓ Found 253 frequent itemsets (vs 42 before)
- 1-itemsets: 77
- 2-itemsets: 158
- 3-itemsets: 18
- 4+ itemsets: 0
✓ Generated 136 rules (91 pair rules)
[STEP 3] Building richer graph structure...
✓ Graph: 41 nodes, 89 edges (vs 6 nodes before)
✓ MST: 41 nodes, 40 edges
ℹ Connected components: 1
[STEP 4] Analyzing 3-itemsets structural patterns...
✓ Analyzed 18/18 3-itemsets (vs 0 before)
Top 5 Fundamental Bundles (Lowest Distance):
1. {'whole milk', 'rolls/buns', 'other vegetables'}
Support=0.0229, Distance=24.21, Bridges=2
2. {'whole milk', 'other vegetables', 'yogurt'}
Support=0.0285, Distance=24.36, Bridges=2
3. {'whole milk', 'root vegetables', 'other vegetables'}
Support=0.0297, Distance=26.39, Bridges=2
4. {'whole milk', 'rolls/buns', 'yogurt'}
Support=0.0199, Distance=27.71, Bridges=2
5. {'whole milk', 'tropical fruit', 'other vegetables'}
Support=0.0219, Distance=28.88, Bridges=2
Bottom 3 Niche Patterns (Highest Distance):
1. {'tropical fruit', 'root vegetables', 'other vegetables'}
Support=0.0158, Distance=44.84, Bridges=3
2. {'tropical fruit', 'other vegetables', 'yogurt'}
Support=0.0158, Distance=42.81, Bridges=3
3. {'root vegetables', 'other vegetables', 'yogurt'}
Support=0.0165, Distance=40.32, Bridges=3
======================================================================
VALIDATION FRAMEWORK
======================================================================
[TEST 1] Correlation Test
Spearman ρ = -0.7994, p = 0.0001
Status: ✓ PASSED
[TEST 2] Categorization Test
⚠ Not enough data
======================================================================
OVERALL: 1/2 tests passed
🎯 CONCLUSION: Structural distance hypothesis is SUPPORTED
======================================================================
[STEP 6] Generating enhanced report...
======================================================================
ENHANCED PATTERNS REPORT (Top 20)
======================================================================
Itemset Size Support Pattern Bridges Distance Compactness Category
{other vegetables, whole milk} 2 0.095883 Direct 0 10.429348 1.0 Fundamental
{rolls/buns, whole milk} 2 0.072564 Direct 0 13.780969 1.0 Fundamental
{whole milk, yogurt} 2 0.071782 Direct 0 13.931034 1.0 Fundamental
{root vegetables, whole milk} 2 0.062663 Direct 0 15.958420 1.0 Fundamental
{whole milk, tropical fruit} 2 0.054195 Direct 0 18.451923 1.0 Fundamental
{whole milk, soda} 2 0.051329 Direct 0 19.482234 1.0 Fundamental
{whole milk, bottled water} 2 0.044033 Direct 0 22.710059 1.0 Fundamental
{whole milk, pastry} 2 0.042600 Direct 0 23.474006 1.0 Fundamental
{whole milk, whipped/sour cream} 2 0.041298 Direct 0 24.214511 1.0 Fundamental
{rolls/buns, sausage} 2 0.039213 Direct 0 25.501661 1.0 Fundamental
{whole milk, citrus fruit} 2 0.039083 Direct 0 25.586667 1.0 Fundamental
{whole milk, pip fruit} 2 0.038562 Direct 0 25.932432 1.0 Fundamental
{whole milk, domestic eggs} 2 0.038431 Direct 0 26.020339 1.0 Fundamental
{whole milk, butter} 2 0.035305 Direct 0 28.324723 1.0 Fundamental
{whole milk, newspapers} 2 0.035044 Direct 0 28.535316 1.0 Fundamental
{whole milk, fruit/vegetable juice} 2 0.034132 Direct 0 29.297710 1.0 Fundamental
{whole milk, curd} 2 0.033481 Direct 0 29.867704 1.0 Fundamental
{whole milk, brown bread} 2 0.032308 Direct 0 30.951613 1.0 Moderate
{soda, shopping bags} 2 0.031527 Direct 0 31.719008 1.0 Moderate
{whole milk, margarine} 2 0.031006 Direct 0 32.252101 1.0 Moderate
✓ Full report saved to: groceries_improved_report.csv
======================================================================
SUMMARY STATISTICS
======================================================================
Pattern Distribution:
Fundamental: 25 (43.1%)
Moderate: 30 (51.7%)
Niche: 3 (5.2%)
Distance Statistics:
Overall Mean: 35.28
Fundamental Mean: 24.10
Moderate Mean: 41.84
Niche Mean: 62.83
======================================================================
CONCLUSION
======================================================================
✓ Improved analysis with lower support threshold (1.5%) successfully
captured MORE patterns while maintaining structural insights:
Before (5% support): 42 itemsets, 6 nodes, 0 3-itemsets
After (1.5% support): 253 itemsets, 41 nodes, 18 3-itemsets
The hybrid approach now reveals:
→ 25 FUNDAMENTAL bundles (naturally connected)
→ 30 MODERATE patterns (intermediate structure)
→ 3 NICHE patterns (opportunistic co-occurrence)
This demonstrates the value of structural context in MBA!
======================================================================
✓ Improved analysis completed!
Next: Run visualize_mst.py to see the graphs!
In [2]:
"""
Visualize MST Results from Current Analysis
Run this AFTER running hybrid_mba_fixed.py
to see the graph visualization
"""
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import FancyBboxPatch
def visualize_mst_from_report(report_file="groceries_hybrid_report.csv"):
"""
Visualize MST from the generated report
"""
# Load report
df = pd.read_csv(report_file)
# Filter only direct connections (2-itemsets = MST edges)
df_direct = df[df['Pattern'] == 'Direct'].copy()
print("="*70)
print("MST VISUALIZATION")
print("="*70)
print(f"Loaded {len(df_direct)} MST edges from report")
# Build MST from report
MST = nx.Graph()
for _, row in df_direct.iterrows():
# Parse itemset
itemset_str = row['Itemset']
items = [item.strip() for item in itemset_str.strip('{}').split(',')]
if len(items) == 2:
u, v = items[0].strip(), items[1].strip()
MST.add_edge(u, v,
weight=row['Distance'],
support=row['Support'],
category=row['Category'])
print(f"MST Graph: {len(MST.nodes())} nodes, {len(MST.edges())} edges")
# Create visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 8))
# === LEFT PLOT: MST Network ===
pos = nx.spring_layout(MST, k=2, iterations=50, seed=42)
# Node colors by centrality
centrality = nx.degree_centrality(MST)
node_colors = [centrality[node] for node in MST.nodes()]
# Draw network
nx.draw_networkx_nodes(MST, pos,
node_color=node_colors,
node_size=3000,
cmap='YlOrRd',
alpha=0.9,
ax=ax1)
# Edge colors by category
edge_colors = []
edge_widths = []
for u, v, data in MST.edges(data=True):
if data['category'] == 'Fundamental':
edge_colors.append('green')
edge_widths.append(4)
elif data['category'] == 'Moderate':
edge_colors.append('orange')
edge_widths.append(3)
else:
edge_colors.append('red')
edge_widths.append(2)
nx.draw_networkx_edges(MST, pos,
edge_color=edge_colors,
width=edge_widths,
alpha=0.7,
ax=ax1)
# Labels
nx.draw_networkx_labels(MST, pos,
font_size=10,
font_weight='bold',
ax=ax1)
# Edge labels (distance)
edge_labels = {(u, v): f"{d['weight']:.1f}"
for u, v, d in MST.edges(data=True)}
nx.draw_networkx_edge_labels(MST, pos,
edge_labels,
font_size=8,
font_color='blue',
ax=ax1)
ax1.set_title('Minimum Spanning Tree\n(Node size = Centrality, Edge color = Category)',
fontsize=14, fontweight='bold')
ax1.axis('off')
# Legend
from matplotlib.lines import Line2D
legend_elements = [
Line2D([0], [0], color='green', linewidth=4, label='Fundamental (<20)'),
Line2D([0], [0], color='orange', linewidth=3, label='Moderate (20-40)'),
Line2D([0], [0], color='red', linewidth=2, label='Niche (>40)')
]
ax1.legend(handles=legend_elements, loc='upper left', fontsize=10)
# === RIGHT PLOT: Edge Distance Distribution ===
distances = [d['weight'] for u, v, d in MST.edges(data=True)]
categories = [d['category'] for u, v, d in MST.edges(data=True)]
# Histogram
ax2.hist(distances, bins=10, color='steelblue', alpha=0.7, edgecolor='black')
ax2.axvline(np.mean(distances), color='red', linestyle='--', linewidth=2,
label=f'Mean: {np.mean(distances):.2f}')
ax2.axvline(np.median(distances), color='green', linestyle='--', linewidth=2,
label=f'Median: {np.median(distances):.2f}')
ax2.set_xlabel('Structural Distance', fontsize=12)
ax2.set_ylabel('Frequency', fontsize=12)
ax2.set_title('Distribution of Edge Distances', fontsize=14, fontweight='bold')
ax2.legend(fontsize=10)
ax2.grid(True, alpha=0.3)
# Add category counts
from collections import Counter
cat_counts = Counter(categories)
text_str = "Edge Categories:\n"
for cat, count in cat_counts.items():
text_str += f" {cat}: {count}\n"
ax2.text(0.98, 0.98, text_str,
transform=ax2.transAxes,
fontsize=10,
verticalalignment='top',
horizontalalignment='right',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
plt.tight_layout()
plt.savefig('mst_visualization.png', dpi=300, bbox_inches='tight')
print("\n✓ Saved: mst_visualization.png")
plt.show()
# === ADDITIONAL: Network Statistics ===
print("\n" + "="*70)
print("NETWORK STATISTICS")
print("="*70)
print(f"\nNodes: {len(MST.nodes())}")
print(f"Edges: {len(MST.edges())}")
print(f"Density: {nx.density(MST):.4f}")
print(f"Is Connected: {nx.is_connected(MST)}")
if nx.is_connected(MST):
print(f"Diameter: {nx.diameter(MST)}")
print(f"Average Path Length: {nx.average_shortest_path_length(MST):.2f}")
# Node centrality
print(f"\nTop 3 Central Nodes (Degree Centrality):")
top_nodes = sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:3]
for node, cent in top_nodes:
print(f" {node}: {cent:.4f}")
# Edge statistics
print(f"\nEdge Distance Statistics:")
print(f" Min: {min(distances):.2f}")
print(f" Max: {max(distances):.2f}")
print(f" Mean: {np.mean(distances):.2f}")
print(f" Median: {np.median(distances):.2f}")
print(f" Std Dev: {np.std(distances):.2f}")
print("\n" + "="*70)
def create_enhanced_report_viz(report_file="groceries_hybrid_report.csv"):
"""
Create enhanced visualization with support vs distance
"""
df = pd.read_csv(report_file)
df_direct = df[df['Pattern'] == 'Direct'].copy()
fig, ax = plt.subplots(figsize=(12, 8))
# Scatter plot: Support vs Distance
colors = {'Fundamental': 'green', 'Moderate': 'orange', 'Niche': 'red'}
for category in df_direct['Category'].unique():
subset = df_direct[df_direct['Category'] == category]
ax.scatter(subset['Support'], subset['Distance'],
label=category,
color=colors.get(category, 'gray'),
s=200,
alpha=0.7,
edgecolors='black',
linewidth=1.5)
# Add labels for each point
for _, row in df_direct.iterrows():
itemset_str = row['Itemset'].strip('{}')
items = [item.strip() for item in itemset_str.split(',')]
label = f"{items[0][:15]}\n{items[1][:15]}"
ax.annotate(label,
(row['Support'], row['Distance']),
fontsize=8,
ha='center',
bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.3))
ax.set_xlabel('Support (Co-occurrence Frequency)', fontsize=12, fontweight='bold')
ax.set_ylabel('Structural Distance', fontsize=12, fontweight='bold')
ax.set_title('Support vs Structural Distance\n(Fundamental bundles = High support + Low distance)',
fontsize=14, fontweight='bold')
ax.legend(fontsize=11, loc='upper right')
ax.grid(True, alpha=0.3, linestyle='--')
# Add threshold lines
ax.axhline(y=20, color='green', linestyle='--', alpha=0.5, label='Fundamental threshold')
ax.axhline(y=40, color='orange', linestyle='--', alpha=0.5, label='Moderate threshold')
plt.tight_layout()
plt.savefig('support_vs_distance.png', dpi=300, bbox_inches='tight')
print("✓ Saved: support_vs_distance.png")
plt.show()
# ====================================
# MAIN
# ====================================
if __name__ == "__main__":
print("="*70)
print("MST VISUALIZATION TOOL")
print("="*70)
# Check if report exists
import os
report_file = "groceries_hybrid_report.csv"
if not os.path.exists(report_file):
print(f"\n❌ Error: {report_file} not found!")
print("Please run hybrid_mba_fixed.py first to generate the report.")
else:
# Visualize MST
visualize_mst_from_report(report_file)
# Create enhanced viz
print("\nCreating enhanced visualization...")
create_enhanced_report_viz(report_file)
print("\n✓ All visualizations complete!")
print("\nGenerated files:")
print(" 1. mst_visualization.png - MST network + distance distribution")
print(" 2. support_vs_distance.png - Support vs distance scatter plot")
====================================================================== MST VISUALIZATION TOOL ====================================================================== ====================================================================== MST VISUALIZATION ====================================================================== Loaded 5 MST edges from report MST Graph: 6 nodes, 5 edges ✓ Saved: mst_visualization.png
====================================================================== NETWORK STATISTICS ====================================================================== Nodes: 6 Edges: 5 Density: 0.3333 Is Connected: True Diameter: 2 Average Path Length: 1.67 Top 3 Central Nodes (Degree Centrality): whole milk: 1.0000 other vegetables: 0.2000 rolls/buns: 0.2000 Edge Distance Statistics: Min: 10.43 Max: 18.45 Mean: 14.51 Median: 13.93 Std Dev: 2.65 ====================================================================== Creating enhanced visualization... ✓ Saved: support_vs_distance.png
✓ All visualizations complete! Generated files: 1. mst_visualization.png - MST network + distance distribution 2. support_vs_distance.png - Support vs distance scatter plot
In [ ]: