"""
api_usage_tutorial.py

PalmLab API Tutorial
A comprehensive guide to using the RESTful API with interactive examples.

Author: PalmLab Team
Date: 2025
Version: 1.2

This tutorial demonstrates how to use all API endpoints with:
1. Interactive command-line interface
2. Parameter validation and error handling
3. Multiple output formats (JSON, CSV, PNG)
4. Download options
5. Example data and use cases

Usage:
    python api_usage_tutorial.py

Requirements:
    - Python 3.7+
    - requests library (pip install requests)
    - pandas library for data analysis (pip install pandas)
    - matplotlib for visualization (pip install matplotlib)
"""

import requests
import pandas as pd
import json
import time
import os
import sys
from datetime import datetime
from typing import Dict, List, Optional, Union, Tuple
import matplotlib.pyplot as plt
from io import BytesIO
import base64

# ==================== Configuration ====================
# Root URL of the PalmLab REST API; every endpoint path used below
# (e.g. "tools2/network/") is appended to it after a "/".
BASE_URL = "https://palmlab.intelligent-oncology.com/api"

# Example data for testing
# One entry per menu tool key. The interactive helpers join any 'proteins'
# list into a comma-separated string before the dict is sent as query params.
# NOTE(review): 'description' is not a documented API parameter, yet it is
# forwarded to the server together with the rest of the dict — confirm the
# server ignores unknown parameters.
EXAMPLE_DATA = {
    # NOTE(review): the description mentions "cancer vs normal" while the
    # example uses mode "default" — confirm which is intended.
    "tools1": {
        "proteins": ["P01112", "P07900", "A0A0J9YXG8", "P17252", "P78460"],
        "species": "human",
        "mode": "default",
        "description": "Differential expression analysis of cancer vs normal"
    },
    "tools2": {
        "protein": "P19096",
        "species": "Mouse",
        "tissue": "All",
        "description": "Protein interaction network for P19096 (Mouse)"
    },
    "tools3": {
        "protein1": "P19096",
        "protein2": "Q01279",
        "species": "Mouse",
        "tissue": "All",
        "description": "Protein pair relationship analysis"
    },
    "tools4_protein": {
        "protein_query": "P01116",  # KRAS
        "mutation_type_choice": "SNV",
        "page": 1,
        "page_size": 20,
        "description": "Search mutations in KRAS protein"
    },
    "tools4_gene": {
        "mutation_gene_query": "TP53",
        "mutation_type_choice": "All",
        "page": 1,
        "page_size": 20,
        "description": "Search TP53 gene mutations"
    },
    "tools5": {
        "proteins": ["Q3U6Q4", "Q61409", "Q9EQL1", "P63085"],
        "species": "Mouse",
        "description": "Multi-protein expression analysis"
    },
    "tools6": {
        "proteins": ["P01112", "P07900", "A0A0J9YXG8"],
        "species": "human",
        "window_size": 6,
        "data_sources": "experimental,database,prediction",
        "analysis_method": "frequency",
        "description": "Motif pattern analysis"
    }
}

# Tools that support image output
# Maps tool key -> human-readable label; get_format_choice only offers the
# "3. Image" option for keys present here.
TOOLS_WITH_IMAGES = {
    "tools3": "Protein pair heatmap",
    "tools5": "Expression heatmap and PCA",
    "tools6": "Motif diagram"
}

# ==================== Utility Functions ====================
def print_header(title: str):
    """Print *title* centred in an 80-column banner framed by '=' rules.

    A blank line is emitted first so consecutive sections are visually separated.
    """
    rule = "=" * 80
    print("\n" + rule)
    print(format(title, "^80"))
    print(rule)

def print_success(message: str):
    """Write *message* to stdout after a check mark, coloured green with ANSI codes."""
    green = "\033[92m"
    reset = "\033[0m"
    print(f"{green}✓ {message}{reset}")

def print_error(message: str):
    """Write *message* to stdout after a cross mark, coloured red with ANSI codes."""
    red = "\033[91m"
    reset = "\033[0m"
    print(f"{red}✗ {message}{reset}")

def print_warning(message: str):
    """Write *message* to stdout after a warning sign, coloured yellow with ANSI codes."""
    yellow = "\033[93m"
    reset = "\033[0m"
    print(f"{yellow}⚠ {message}{reset}")

def print_info(message: str):
    """Write *message* to stdout after an info symbol, coloured blue with ANSI codes."""
    blue = "\033[94m"
    reset = "\033[0m"
    print(f"{blue}ℹ {message}{reset}")

def validate_protein_list(proteins_str: str) -> List[str]:
    """Split a comma-separated protein string into trimmed, non-empty IDs.

    Raises:
        ValueError: if no non-blank entry remains after splitting.
    """
    stripped = (token.strip() for token in proteins_str.split(','))
    parsed = [token for token in stripped if token]
    if parsed:
        return parsed
    raise ValueError("No valid proteins provided")

def make_api_request(endpoint: str, params: Dict, timeout: int = 30) -> Optional[Dict]:
    """GET ``BASE_URL/endpoint`` with *params* and return the decoded JSON body.

    Every failure is reported through ``print_error`` and turned into ``None``,
    so callers only need a truthiness check:

    - timeouts, connection errors and any other ``requests`` failure;
    - a non-200 HTTP status (the response body is echoed);
    - a 200 response whose body is not JSON (e.g. CSV or PNG returned because a
      non-JSON ``format`` was requested).

    Args:
        endpoint: Path relative to BASE_URL, e.g. ``'tools2/network/'``.
        params: Query-string parameters.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The parsed JSON payload, or None on any failure.
    """
    url = f"{BASE_URL}/{endpoint}"
    print_info(f"Request URL: {url}")
    print_info(f"Parameters: {params}")

    try:
        response = requests.get(url, params=params, timeout=timeout)
    except requests.exceptions.Timeout:
        print_error("Request timed out")
        return None
    except requests.exceptions.ConnectionError:
        print_error("Connection error - check server availability")
        return None
    except requests.exceptions.RequestException as e:
        print_error(f"Request failed: {e}")
        return None

    if response.status_code != 200:
        print_error(f"HTTP Error {response.status_code}: {response.text}")
        return None

    try:
        return response.json()
    except ValueError as e:
        # Response.json() raises a ValueError subclass (json.JSONDecodeError, or
        # requests.exceptions.JSONDecodeError in newer requests) on non-JSON
        # bodies; older requests versions do not make it a RequestException.
        print_error(f"Response is not valid JSON: {e}")
        return None

def save_json(data: Dict, filename: str):
    """Serialise *data* as 2-space-indented JSON into *filename*, then report it.

    The file is overwritten if it already exists.
    """
    with open(filename, 'w') as out_file:
        json.dump(data, out_file, indent=2)
    print_success(f"JSON saved to {filename}")

def save_csv(content: str, filename: str):
    """Write CSV text to *filename* as UTF-8 and report success.

    The file is opened with ``newline=''`` so the line endings already present
    in *content* are written unchanged. CSV text commonly uses CRLF row
    terminators; in default text mode on Windows each LF would be translated
    to CRLF again, producing CR CR LF and blank rows in spreadsheet tools.
    """
    with open(filename, 'w', encoding='utf-8', newline='') as f:
        f.write(content)
    print_success(f"CSV saved to {filename}")

def save_image(image_data: bytes, filename: str):
    """Write raw image bytes (the PNG payloads fetched by the tools) to *filename*."""
    with open(filename, 'wb') as binary_out:
        binary_out.write(image_data)
    print_success(f"Image saved to {filename}")

def display_results_summary(data: Dict):
    """Print a human-readable digest of an API response.

    When ``data['success']`` is falsy only the error message is shown.
    Otherwise prints the success message, the echoed ``query`` (list values
    truncated to five items), the ``summary`` mapping, and a record count taken
    from ``results`` when it is a list, or from ``results['records']`` when
    ``results`` is a dict.
    """
    if not data.get('success'):
        print_error(f"API Error: {data.get('message', 'Unknown error')}")
        return

    print_success(f"Success: {data.get('message', '')}")

    query_info = data.get('query', {})
    if query_info:
        print_info("Query Information:")
        for name, val in query_info.items():
            if not isinstance(val, list):
                print(f"  {name}: {val}")
                continue
            shown = ', '.join(str(item) for item in val[:5])
            more = '...' if len(val) > 5 else ''
            print(f"  {name}: {shown}{more}")

    summary_info = data.get('summary', {})
    if summary_info:
        print_info("Summary:")
        for name, val in summary_info.items():
            print(f"  {name}: {val}")

    # Results may be a plain list or a dict wrapping a 'records' list.
    records = data.get('results')
    if isinstance(records, dict):
        records = records.get('records', [])
    if isinstance(records, list):
        print_info(f"Results count: {len(records)}")

# ==================== API Functions ====================
def tools1_differential_analysis(params: Dict, download: bool = False, format_type: str = 'json'):
    """
    Tools1: Differential Expression Analysis

    Parameters:
    - proteins: Comma-separated list of protein IDs or gene symbols (REQUIRED);
      a Python list/tuple is also accepted and joined with commas
    - species: human or mouse (default: human)
    - mode: default, cancer_vs_normal, or custom (default: default)
    - group_a_datasets: Comma-separated list (required for custom mode)
    - group_b_datasets: Comma-separated list (required for custom mode)
    - group_a_label: Custom label for group A (optional)
    - group_b_label: Custom label for group B (optional)

    Args:
        params: Query parameters as listed above. The dict is copied, so the
            caller's dict (e.g. an EXAMPLE_DATA entry) is never modified.
        download: If True, send ``download=true`` and save the result to a
            timestamped file in the current directory.
        format_type: 'json' or 'csv'. The on-screen summary is always fetched
            as JSON; for 'csv' with ``download`` a second request retrieves the
            CSV export.

    Returns:
        The parsed JSON response, or None when validation or the request fails.
    """
    print_header("Tools1: Differential Expression Analysis")

    # Copy so 'format'/'download' never leak back into the caller's dict and,
    # from there, into the next run of the tutorial.
    params = dict(params)

    # Validate required parameters
    if not params.get('proteins'):
        print_error("Error: 'proteins' parameter is required")
        return None
    if isinstance(params['proteins'], (list, tuple)):
        params['proteins'] = ','.join(map(str, params['proteins']))

    # Set default parameters
    params.setdefault('species', 'human')
    params.setdefault('mode', 'default')
    if download:
        params['download'] = 'true'

    # Validate mode-specific parameters
    mode = params['mode']
    if mode == 'custom':
        if not params.get('group_a_datasets') or not params.get('group_b_datasets'):
            print_error("Error: Both 'group_a_datasets' and 'group_b_datasets' are required for custom mode")
            return None

    # The summary response is parsed as JSON, so never request CSV here;
    # the CSV export is fetched separately below.
    params['format'] = 'json' if format_type == 'csv' else format_type
    data = make_api_request('tools1/differential/', params)
    if not data:
        return data

    # Also reports API-level failures (success == False) instead of staying silent.
    display_results_summary(data)
    if not (download and data.get('success')):
        return data

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    stem = f"tools1_differential_{mode}_{timestamp}"
    if format_type == 'json':
        save_json(data, f"{stem}.json")
    elif format_type == 'csv':
        params['format'] = 'csv'
        try:
            response = requests.get(f"{BASE_URL}/tools1/differential/", params=params, timeout=30)
        except requests.exceptions.RequestException as exc:
            print_error(f"CSV download failed: {exc}")
        else:
            if response.status_code == 200:
                save_csv(response.text, f"{stem}.csv")
            else:
                print_error(f"CSV download failed: HTTP {response.status_code}")

    return data

def tools2_protein_interactions(params: Dict, download: bool = False, format_type: str = 'json'):
    """
    Tools2: Protein Interaction Network

    Parameters:
    - protein: Single protein ID
    - species: Mouse or Human (default: Mouse)
    - tissue: All, Brain, Liver (for Mouse) or Tumor, Normal (for Human)

    Args:
        params: Query parameters as listed above. The dict is copied, so the
            caller's dict (e.g. an EXAMPLE_DATA entry) is never modified.
        download: If True, send ``download=true`` and save the result to a
            timestamped file in the current directory.
        format_type: 'json' or 'csv'. The on-screen summary is always fetched
            as JSON; for 'csv' with ``download`` a second request retrieves the
            CSV export.

    Returns:
        The parsed JSON response, or None when validation or the request fails.
    """
    print_header("Tools2: Protein Interaction Network")

    # Copy so 'format'/'download' never leak into the caller's dict.
    params = dict(params)

    # Validate required parameters
    if not params.get('protein'):
        print_error("Error: 'protein' parameter is required")
        return None

    # Set default parameters
    params.setdefault('species', 'Mouse')
    params.setdefault('tissue', 'All')
    if download:
        params['download'] = 'true'

    # The summary response is parsed as JSON, so never request CSV here.
    params['format'] = 'json' if format_type == 'csv' else format_type
    data = make_api_request('tools2/network/', params)
    if not data:
        return data

    # Also reports API-level failures (success == False).
    display_results_summary(data)
    if not (download and data.get('success')):
        return data

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    stem = f"tools2_interactions_{params['protein']}_{timestamp}"
    if format_type == 'json':
        save_json(data, f"{stem}.json")
    elif format_type == 'csv':
        params['format'] = 'csv'
        try:
            response = requests.get(f"{BASE_URL}/tools2/network/", params=params, timeout=30)
        except requests.exceptions.RequestException as exc:
            print_error(f"CSV download failed: {exc}")
        else:
            if response.status_code == 200:
                save_csv(response.text, f"{stem}.csv")
            else:
                print_error(f"CSV download failed: HTTP {response.status_code}")

    return data

def tools3_protein_pair(params: Dict, download: bool = False, format_type: str = 'json', get_image: bool = False):
    """
    Tools3: Protein Pair Relationship Analysis

    Parameters:
    - protein1: First protein ID
    - protein2: Second protein ID
    - species: Mouse or Human (default: Mouse)
    - tissue: Tissue type
    - output_image: true/false (for heatmap)

    Args:
        params: Query parameters as listed above. The dict is copied, so keys
            such as 'output_image' never leak into the caller's dict.
        download: If True, send ``download=true`` and save the result to a
            timestamped file in the current directory.
        format_type: 'json' or 'csv' for the data request (unused when
            ``get_image`` is True). The summary is always fetched as JSON;
            CSV is fetched separately when downloading.
        get_image: If True, request the heatmap PNG instead of data and save it.

    Returns:
        With ``get_image``: True if the heatmap was saved, False otherwise.
        Otherwise: the parsed JSON response, or None on failure.
    """
    print_header("Tools3: Protein Pair Relationship Analysis")

    params = dict(params)

    # Validate required parameters
    if not params.get('protein1') or not params.get('protein2'):
        print_error("Error: Both 'protein1' and 'protein2' parameters are required")
        return None

    # Set default parameters
    params.setdefault('species', 'Mouse')
    params.setdefault('tissue', 'All')
    if download:
        params['download'] = 'true'
    protein1 = params['protein1']
    protein2 = params['protein2']

    if get_image:
        params['output_image'] = 'true'
        params['format'] = 'image'
        try:
            response = requests.get(f"{BASE_URL}/tools3/pair/", params=params, timeout=30)
        except requests.exceptions.RequestException as exc:
            print_error(f"Failed to download image: {exc}")
            return False
        if response.status_code != 200:
            print_error(f"Failed to download image: {response.status_code}")
            return False
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"tools3_heatmap_{protein1}_{protein2}_{timestamp}.png"
        save_image(response.content, filename)
        print_success(f"Heatmap image downloaded: {filename}")
        return True

    # The summary response is parsed as JSON, so never request CSV here.
    params['format'] = 'json' if format_type == 'csv' else format_type
    data = make_api_request('tools3/pair/', params)
    if not data:
        return data

    # Also reports API-level failures (success == False).
    display_results_summary(data)
    if not (download and data.get('success')):
        return data

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    stem = f"tools3_pair_{protein1}_{protein2}_{timestamp}"
    if format_type == 'json':
        save_json(data, f"{stem}.json")
    elif format_type == 'csv':
        params['format'] = 'csv'
        try:
            response = requests.get(f"{BASE_URL}/tools3/pair/", params=params, timeout=30)
        except requests.exceptions.RequestException as exc:
            print_error(f"CSV download failed: {exc}")
        else:
            if response.status_code == 200:
                save_csv(response.text, f"{stem}.csv")
            else:
                print_error(f"CSV download failed: HTTP {response.status_code}")

    return data

def tools4_search_protein_mutations(params: Dict, download: bool = False, format_type: str = 'json'):
    """
    Tools4: Search Protein Mutations

    Parameters:
    - protein_query: Protein ID or gene symbol
    - mutation_type_choice: CNV_Amplification, CNV_Deletion, Hotspot_gene, SNV, All
    - page: Page number (default: 1)
    - page_size: Records per page (default: 20)

    Args:
        params: Query parameters as listed above. The dict is copied, so the
            caller's dict (e.g. an EXAMPLE_DATA entry) is never modified.
        download: If True, send ``download=true`` and save the result to a
            timestamped file in the current directory.
        format_type: 'json' or 'csv'. The summary is always fetched as JSON;
            CSV is fetched separately when downloading.

    Returns:
        The parsed JSON response, or None when validation or the request fails.
    """
    print_header("Tools4: Search Protein Mutations")

    params = dict(params)

    # Validate required parameters
    if not params.get('protein_query'):
        print_error("Error: 'protein_query' parameter is required")
        return None

    # Set default parameters
    params.setdefault('mutation_type_choice', 'All')
    params.setdefault('page', 1)
    params.setdefault('page_size', 20)
    if download:
        params['download'] = 'true'

    # The summary response is parsed as JSON, so never request CSV here.
    params['format'] = 'json' if format_type == 'csv' else format_type
    data = make_api_request('tools4/searchpalmitoylatedprotein/', params)
    if not data:
        return data

    # Also reports API-level failures (success == False).
    display_results_summary(data)
    if not (download and data.get('success')):
        return data

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    stem = f"tools4_protein_mutations_{params['protein_query']}_{timestamp}"
    if format_type == 'json':
        save_json(data, f"{stem}.json")
    elif format_type == 'csv':
        params['format'] = 'csv'
        try:
            response = requests.get(f"{BASE_URL}/tools4/searchpalmitoylatedprotein/", params=params, timeout=30)
        except requests.exceptions.RequestException as exc:
            print_error(f"CSV download failed: {exc}")
        else:
            if response.status_code == 200:
                save_csv(response.text, f"{stem}.csv")
            else:
                print_error(f"CSV download failed: HTTP {response.status_code}")

    return data

def tools4_search_gene_mutations(params: Dict, download: bool = False, format_type: str = 'json'):
    """
    Tools4: Search Gene Mutations

    Parameters:
    - mutation_gene_query: Gene name
    - mutation_type_choice: CNV_Amplification, CNV_Deletion, Hotspot_gene, SNV, All
    - page: Page number (default: 1)
    - page_size: Records per page (default: 20)

    Args:
        params: Query parameters as listed above. The dict is copied, so the
            caller's dict (e.g. an EXAMPLE_DATA entry) is never modified.
        download: If True, send ``download=true`` and save the result to a
            timestamped file in the current directory.
        format_type: 'json' or 'csv'. The summary is always fetched as JSON;
            CSV is fetched separately when downloading.

    Returns:
        The parsed JSON response, or None when validation or the request fails.
    """
    print_header("Tools4: Search Gene Mutations")

    params = dict(params)

    # Validate required parameters
    if not params.get('mutation_gene_query'):
        print_error("Error: 'mutation_gene_query' parameter is required")
        return None

    # Set default parameters
    params.setdefault('mutation_type_choice', 'All')
    params.setdefault('page', 1)
    params.setdefault('page_size', 20)
    if download:
        params['download'] = 'true'

    # The summary response is parsed as JSON, so never request CSV here.
    params['format'] = 'json' if format_type == 'csv' else format_type
    data = make_api_request('tools4/searchgenemutation/', params)
    if not data:
        return data

    # Also reports API-level failures (success == False).
    display_results_summary(data)
    if not (download and data.get('success')):
        return data

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    stem = f"tools4_gene_mutations_{params['mutation_gene_query']}_{timestamp}"
    if format_type == 'json':
        save_json(data, f"{stem}.json")
    elif format_type == 'csv':
        params['format'] = 'csv'
        try:
            response = requests.get(f"{BASE_URL}/tools4/searchgenemutation/", params=params, timeout=30)
        except requests.exceptions.RequestException as exc:
            print_error(f"CSV download failed: {exc}")
        else:
            if response.status_code == 200:
                save_csv(response.text, f"{stem}.csv")
            else:
                print_error(f"CSV download failed: HTTP {response.status_code}")

    return data

def tools5_multi_protein_analysis(params: Dict, download: bool = False, format_type: str = 'json',
                                  get_image: bool = False, include_pca: bool = False):
    """
    Tools5: Multi-Protein Expression Analysis

    Parameters:
    - proteins: Comma-separated list of protein IDs (a Python list/tuple is
      also accepted and joined with commas)
    - species: Mouse or Human (default: Mouse)
    - output_image: true/false (for heatmap)
    - include_pca: true/false (include PCA analysis)

    Args:
        params: Query parameters as listed above. The dict is copied, so the
            caller's dict is never modified.
        download: If True, send ``download=true`` and save the result to a
            timestamped file in the current directory.
        format_type: 'json' or 'csv' for the data request (unused when
            ``get_image`` is True). The summary is always fetched as JSON;
            CSV is fetched separately when downloading.
        get_image: If True, request the image instead of data and save it.
        include_pca: If True, send ``include_pca=true``.

    Returns:
        With ``get_image``: True if the image was saved, False otherwise.
        Otherwise: the parsed JSON response, or None on failure.
    """
    print_header("Tools5: Multi-Protein Expression Analysis")

    params = dict(params)

    # Validate required parameters
    if not params.get('proteins'):
        print_error("Error: 'proteins' parameter is required")
        return None
    if isinstance(params['proteins'], (list, tuple)):
        params['proteins'] = ','.join(map(str, params['proteins']))

    # Set default parameters
    params.setdefault('species', 'Mouse')
    if download:
        params['download'] = 'true'
    if include_pca:
        params['include_pca'] = 'true'

    if get_image:
        params['output_image'] = 'true'
        params['format'] = 'image'
        try:
            response = requests.get(f"{BASE_URL}/tools5/multi/", params=params, timeout=30)
        except requests.exceptions.RequestException as exc:
            print_error(f"Failed to download image: {exc}")
            return False
        if response.status_code != 200:
            print_error(f"Failed to download image: {response.status_code}")
            return False
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        image_type = 'pca' if include_pca else 'heatmap'
        filename = f"tools5_{image_type}_{timestamp}.png"
        save_image(response.content, filename)
        print_success(f"Image downloaded: {filename}")
        return True

    # The summary response is parsed as JSON, so never request CSV here.
    params['format'] = 'json' if format_type == 'csv' else format_type
    data = make_api_request('tools5/multi/', params)
    if not data:
        return data

    # Also reports API-level failures (success == False).
    display_results_summary(data)
    if not (download and data.get('success')):
        return data

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    # Count only non-empty entries so input like "A,B," is reported as 2 proteins.
    protein_count = len([p for p in params['proteins'].split(',') if p.strip()])
    stem = f"tools5_expression_{protein_count}_proteins_{timestamp}"
    if format_type == 'json':
        save_json(data, f"{stem}.json")
    elif format_type == 'csv':
        params['format'] = 'csv'
        try:
            response = requests.get(f"{BASE_URL}/tools5/multi/", params=params, timeout=30)
        except requests.exceptions.RequestException as exc:
            print_error(f"CSV download failed: {exc}")
        else:
            if response.status_code == 200:
                save_csv(response.text, f"{stem}.csv")
            else:
                print_error(f"CSV download failed: HTTP {response.status_code}")

    return data

def tools6_motif_analysis(params: Dict, download: bool = False, format_type: str = 'image'):
    """
    Tools6: Motif Pattern Analysis

    Parameters:
    - proteins: Comma-separated list of protein IDs (a Python list/tuple is
      also accepted and joined with commas)
    - species: human or mouse (default: human)
    - window_size: Window size for motif (default: 6)
    - data_sources: experimental,database,prediction (comma-separated)
    - analysis_method: frequency or information (default: frequency)

    Args:
        params: Query parameters as listed above. The dict is copied, so the
            caller's dict is never modified.
        download: If True, send ``download=true``; for 'json'/'csv' the result
            is also saved to a timestamped file.
        format_type: 'image' (default) fetches and saves the motif PNG;
            'json'/'csv' fetch data. The data summary is always fetched as
            JSON; CSV is fetched separately when downloading.

    Returns:
        For 'image': True if the image was saved, False otherwise.
        Otherwise: the parsed JSON response, or None on failure.
    """
    print_header("Tools6: Motif Pattern Analysis")

    params = dict(params)

    # Validate required parameters
    if not params.get('proteins'):
        print_error("Error: 'proteins' parameter is required")
        return None
    if isinstance(params['proteins'], (list, tuple)):
        params['proteins'] = ','.join(map(str, params['proteins']))

    # Set default parameters
    params.setdefault('species', 'human')
    params.setdefault('window_size', 6)
    params.setdefault('data_sources', 'experimental,database,prediction')
    params.setdefault('analysis_method', 'frequency')
    if download:
        params['download'] = 'true'

    # Count only non-empty entries so input like "A,B," is reported as 2 proteins.
    protein_count = len([p for p in params['proteins'].split(',') if p.strip()])

    # Handle image request (default for Tools6)
    if format_type == 'image':
        params['format'] = 'image'
        try:
            response = requests.get(f"{BASE_URL}/tools6/motif/", params=params, timeout=30)
        except requests.exceptions.RequestException as exc:
            print_error(f"Failed to download image: {exc}")
            return False
        if response.status_code != 200:
            print_error(f"Failed to download image: {response.status_code}")
            return False
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        filename = f"tools6_motif_{protein_count}_proteins_{timestamp}.png"
        save_image(response.content, filename)
        print_success(f"Motif image downloaded: {filename}")
        return True

    # The summary response is parsed as JSON, so never request CSV here.
    params['format'] = 'json' if format_type == 'csv' else format_type
    data = make_api_request('tools6/motif/', params)
    if not data:
        return data

    # Also reports API-level failures (success == False).
    display_results_summary(data)
    if not (download and data.get('success')):
        return data

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    stem = f"tools6_motif_{protein_count}_proteins_{timestamp}"
    if format_type == 'json':
        save_json(data, f"{stem}.json")
    elif format_type == 'csv':
        params['format'] = 'csv'
        try:
            response = requests.get(f"{BASE_URL}/tools6/motif/", params=params, timeout=30)
        except requests.exceptions.RequestException as exc:
            print_error(f"CSV download failed: {exc}")
        else:
            if response.status_code == 200:
                save_csv(response.text, f"{stem}.csv")
            else:
                print_error(f"CSV download failed: HTTP {response.status_code}")

    return data

# ==================== Interactive Menu ====================
def display_menu():
    """Show the tutorial banner followed by the numbered list of menu options."""
    print_header("PalmLab API Tutorial")
    menu_lines = (
        "Select an API function to test:",
        "1. Tools1 - Differential Expression Analysis",
        "2. Tools2 - Protein Interaction Network",
        "3. Tools3 - Protein Pair Relationship",
        "4. Tools4a - Search Protein Mutations",
        "5. Tools4b - Search Gene Mutations",
        "6. Tools5 - Multi-Protein Expression Analysis",
        "7. Tools6 - Motif Pattern Analysis",
        "8. Run All Examples (Quick Test)",
        "9. Exit",
        "\nNote: All examples use pre-defined test data.",
        "You can modify parameters in the EXAMPLE_DATA dictionary.",
    )
    for line in menu_lines:
        print(line)

def get_user_choice() -> int:
    """Prompt until the user types a menu number from 1 to 9, then return it.

    Blank input re-prompts silently; non-numeric or out-of-range input prints
    an error and re-prompts.
    """
    while True:
        raw = input("\nEnter your choice (1-9): ").strip()
        if not raw:
            continue
        try:
            number = int(raw)
        except ValueError:
            print_error("Please enter a valid number")
            continue
        if number in range(1, 10):
            return number
        print_error("Please enter a number between 1 and 9")

def get_yes_no(prompt: str, default: bool = False) -> bool:
    """Ask a yes/no question until answered.

    Accepts y/yes/n/no in any case; empty input returns *default*, whose
    value is reflected in the capitalised hint (Y/n or y/N).
    """
    hint = 'Y/n' if default else 'y/N'
    answers = {'y': True, 'yes': True, 'n': False, 'no': False}
    while True:
        reply = input(f"{prompt} ({hint}): ").strip().lower()
        if not reply:
            return default
        if reply in answers:
            return answers[reply]
        print_error("Please enter 'y' or 'n'")

def get_format_choice(tool_name: Optional[str] = None):
    """Ask which output format to use and return 'json', 'csv' or 'image'.

    The image option (3) is offered only when *tool_name* is a key of
    TOOLS_WITH_IMAGES. Empty input selects JSON.
    """
    print("\nSelect output format:")
    print("1. JSON (default)")
    print("2. CSV")

    choices = {'': 'json', '1': 'json', '2': 'csv'}
    supports_image = bool(tool_name) and tool_name in TOOLS_WITH_IMAGES
    if supports_image:
        print(f"3. Image - {TOOLS_WITH_IMAGES[tool_name]}")
        choices['3'] = 'image'
    upper = 3 if supports_image else 2

    while True:
        picked = input(f"Enter format choice (1-{upper}, default=1): ").strip()
        if picked in choices:
            return choices[picked]
        print_error(f"Please enter a number between 1 and {upper}")

def get_custom_or_example(tool_name: str) -> Tuple[Dict, bool]:
    """Collect parameters for *tool_name* from the user or from EXAMPLE_DATA.

    The returned dict is always a fresh copy. The tool functions may add keys
    such as 'format' or 'download'; if the global example dict were returned
    directly, those keys would persist and silently change later runs.

    Returns:
        ``(params, is_custom)``. ``is_custom`` is True when the user chose to
        enter custom parameters. Example 'proteins' lists are joined into a
        comma-separated string.
    """
    if get_yes_no("Use custom parameters instead of example?", False):
        return dict(get_custom_parameters(tool_name)), True

    print_info(f"Using example parameters for {tool_name}")
    example = dict(EXAMPLE_DATA.get(tool_name, {}))

    # Handle list to string conversion
    if isinstance(example.get('proteins'), list):
        example['proteins'] = ','.join(example['proteins'])

    return example, False

def get_tools1_mode_choice() -> str:
    """Prompt for the Tools1 analysis mode and return its API name.

    Returns one of 'default', 'cancer_vs_normal' or 'custom'; empty input
    selects option 1 ('default').
    """
    modes = {'1': 'default', '2': 'cancer_vs_normal', '3': 'custom'}

    print_info("\nSelect Tools1 analysis mode:")
    print("1. Default mode (auto grouping based on species)")
    print("2. Cancer vs Normal mode (human only)")
    print("3. Custom mode (user-defined groups)")

    while True:
        picked = input("Enter mode choice (1-3, default=1): ").strip() or '1'
        mode = modes.get(picked)
        if mode is not None:
            return mode
        print_error("Please enter 1, 2, or 3")

def get_tools1_datasets(species: str) -> List[str]:
    """Return the dataset names offered for Tools1 custom mode (simplified, hard-coded).

    *species* is matched case-insensitively and with surrounding whitespace
    ignored, so 'Human' or ' HUMAN ' select the human list. Any other value
    falls back to the mouse list.
    """
    if species.strip().lower() == 'human':
        return [
            'Jurkat_T_cells', 'U937_cells', 'DU145_cells', 'HeLa_cells', 'LNCaP_cells', 'PC3_cells',
            'T_cells', 'HAP1_cells', '293T_cells', 'CEMx174_cells', 'Endothelial_cells',
            'prefrontal_cortex', 'liver_membrane', 'cerebral_cortex'
        ]
    # mouse (and any unrecognised species)
    return [
        'brain_tissue', 'liver_tissue', 'Macrophage_Raw_264.7', 'NSC', 'testis', 'liver_membrane'
    ]

def _prompt_dataset_group(group_name: str, available_datasets: List[str],
                          taken: Optional[List[str]] = None) -> List[str]:
    """Prompt until the user enters a valid dataset list for one Tools1 group.

    Re-prompts on empty input, on names not in *available_datasets*, and on
    names already assigned to the other group (*taken*).

    Returns:
        The entered dataset names, stripped, in input order.
    """
    already_used = set(taken or [])
    while True:
        raw = input(f"Group {group_name} datasets (comma-separated): ").strip()
        chosen = [ds.strip() for ds in raw.split(',') if ds.strip()]

        if not chosen:
            print_error(f"Please enter at least one dataset for Group {group_name}")
            continue

        # Check if all datasets are valid
        invalid_datasets = [ds for ds in chosen if ds not in available_datasets]
        if invalid_datasets:
            print_error(f"Invalid datasets in Group {group_name}: {', '.join(invalid_datasets)}")
            print_info(f"Available datasets: {', '.join(available_datasets)}")
            continue

        # Check for overlap with the other group
        overlap = already_used & set(chosen)
        if overlap:
            print_error(f"Datasets cannot be in both groups. Overlap: {', '.join(overlap)}")
            continue

        return chosen


def get_custom_parameters(tool_name: str) -> Dict:
    """Interactively collect API parameters for *tool_name*.

    When the required identifier is left blank, a copy of the tool's
    EXAMPLE_DATA entry is returned instead, so callers may freely modify
    the result. Unknown tool names yield an empty dict.
    """
    params = {}

    if tool_name == "tools1":
        print_info("Enter custom parameters for Tools1:")

        # Get analysis mode
        mode = get_tools1_mode_choice()

        # Get proteins
        proteins = input("Proteins (comma-separated, e.g., P01112,P07900): ").strip()
        if not proteins:
            print_warning("No proteins entered, using example data")
            # NOTE(review): in custom mode this example has no group datasets,
            # so tools1_differential_analysis will reject it.
            example = EXAMPLE_DATA['tools1'].copy()
            example['proteins'] = ','.join(example['proteins'])
            example['mode'] = mode
            return example

        # Get species; lower-cased because the mode check and the dataset
        # lookup below compare against 'human'.
        species = (input("Species (human/mouse, default=human): ").strip() or 'human').lower()

        params = {'proteins': proteins, 'species': species, 'mode': mode}

        # Validate mode-specific requirements
        if mode == 'cancer_vs_normal' and species != 'human':
            print_warning("Cancer vs Normal mode is only available for human species.")
            print_warning("Changing species to human.")
            params['species'] = 'human'

        elif mode == 'custom':
            # Get available datasets for the species
            available_datasets = get_tools1_datasets(species)
            print_info(f"Available datasets for {species}: {', '.join(available_datasets)}")

            group_a_datasets = _prompt_dataset_group('A', available_datasets)
            params['group_a_datasets'] = ','.join(group_a_datasets)

            group_b_datasets = _prompt_dataset_group('B', available_datasets, taken=group_a_datasets)
            params['group_b_datasets'] = ','.join(group_b_datasets)

            # Get optional labels
            group_a_label = input("Group A label (optional): ").strip()
            group_b_label = input("Group B label (optional): ").strip()

            if group_a_label:
                params['group_a_label'] = group_a_label
            if group_b_label:
                params['group_b_label'] = group_b_label

    elif tool_name == "tools2":
        print_info("Enter custom parameters for Tools2:")
        protein = input("Protein ID (e.g., P19096): ").strip()
        if not protein:
            print_warning("No protein entered, using example data")
            # Copy: the tool functions may add keys to the returned dict.
            return EXAMPLE_DATA['tools2'].copy()

        species = input("Species (Mouse/Human, default=Mouse): ").strip() or 'Mouse'
        tissue = input("Tissue (All/Brain/Liver for Mouse, Tumor/Normal for Human, default=All): ").strip() or 'All'

        params = {'protein': protein, 'species': species, 'tissue': tissue}

    elif tool_name == "tools3":
        print_info("Enter custom parameters for Tools3:")
        protein1 = input("First protein ID (e.g., P19096): ").strip()
        protein2 = input("Second protein ID (e.g., Q01279): ").strip()

        if not protein1 or not protein2:
            print_warning("Missing protein IDs, using example data")
            return EXAMPLE_DATA['tools3'].copy()

        species = input("Species (Mouse/Human, default=Mouse): ").strip() or 'Mouse'
        tissue = input("Tissue (All/Brain/Liver for Mouse, Tumor/Normal for Human, default=All): ").strip() or 'All'

        params = {'protein1': protein1, 'protein2': protein2, 'species': species, 'tissue': tissue}

    elif tool_name == "tools4_protein":
        print_info("Enter custom parameters for Tools4 Protein Search:")
        protein_query = input("Protein ID or gene symbol (e.g., P01116 or KRAS): ").strip()
        if not protein_query:
            print_warning("No query entered, using example data")
            return EXAMPLE_DATA['tools4_protein'].copy()

        mutation_type = input("Mutation type (CNV_Amplification, CNV_Deletion, Hotspot_gene, SNV, All, default=All): ").strip() or 'All'
        page = input("Page number (default=1): ").strip() or '1'
        page_size = input("Page size (default=20): ").strip() or '20'

        params = {
            'protein_query': protein_query,
            'mutation_type_choice': mutation_type,
            'page': page,
            'page_size': page_size
        }

    elif tool_name == "tools4_gene":
        print_info("Enter custom parameters for Tools4 Gene Search:")
        mutation_gene_query = input("Gene name (e.g., TP53): ").strip()
        if not mutation_gene_query:
            print_warning("No gene name entered, using example data")
            return EXAMPLE_DATA['tools4_gene'].copy()

        mutation_type = input("Mutation type (CNV_Amplification, CNV_Deletion, Hotspot_gene, SNV, All, default=All): ").strip() or 'All'
        page = input("Page number (default=1): ").strip() or '1'
        page_size = input("Page size (default=20): ").strip() or '20'

        params = {
            'mutation_gene_query': mutation_gene_query,
            'mutation_type_choice': mutation_type,
            'page': page,
            'page_size': page_size
        }

    elif tool_name == "tools5":
        print_info("Enter custom parameters for Tools5:")
        proteins = input("Proteins (comma-separated, e.g., Q3U6Q4,Q61409): ").strip()
        if not proteins:
            print_warning("No proteins entered, using example data")
            example = EXAMPLE_DATA['tools5'].copy()
            example['proteins'] = ','.join(example['proteins'])
            return example

        species = input("Species (Mouse/Human, default=Mouse): ").strip() or 'Mouse'

        params = {'proteins': proteins, 'species': species}

    elif tool_name == "tools6":
        print_info("Enter custom parameters for Tools6:")
        proteins = input("Proteins (comma-separated, e.g., P01112,P07900): ").strip()
        if not proteins:
            print_warning("No proteins entered, using example data")
            example = EXAMPLE_DATA['tools6'].copy()
            example['proteins'] = ','.join(example['proteins'])
            return example

        species = input("Species (human/mouse, default=human): ").strip() or 'human'
        window_size = input("Window size (default=6): ").strip() or '6'

        params = {
            'proteins': proteins,
            'species': species,
            'window_size': window_size
        }

        # Optional parameters
        data_sources = input("Data sources (experimental,database,prediction, default=all): ").strip()
        analysis_method = input("Analysis method (frequency/information, default=frequency): ").strip()

        if data_sources:
            params['data_sources'] = data_sources
        if analysis_method:
            params['analysis_method'] = analysis_method

    return params

def run_tools1_example():
    """Interactively run the Tools1 differential-expression example.

    Parameters are obtained from ``get_custom_or_example("tools1")`` and
    echoed back.  The user is then asked for an output format and whether to
    download the results.

    Returns:
        Whatever ``tools1_differential_analysis`` returns.
    """
    tool = "tools1"
    params, _ = get_custom_or_example(tool)
    print_info(f"Using parameters: {params}")

    chosen_format = get_format_choice(tool)
    should_download = get_yes_no("Download results?", False)
    return tools1_differential_analysis(params, should_download, chosen_format)

def run_tools2_example():
    """Interactively run the Tools2 protein-interaction-network example.

    Parameters are obtained from ``get_custom_or_example("tools2")`` and
    echoed back.  The user is then asked for an output format and whether to
    download the results.

    Returns:
        Whatever ``tools2_protein_interactions`` returns.
    """
    tool = "tools2"
    params, _ = get_custom_or_example(tool)
    print_info(f"Using parameters: {params}")

    chosen_format = get_format_choice(tool)
    should_download = get_yes_no("Download results?", False)
    return tools2_protein_interactions(params, should_download, chosen_format)

def run_tools3_example():
    """Interactively run the Tools3 protein-pair relationship example.

    Parameters are obtained from ``get_custom_or_example("tools3")`` and
    echoed back.  The user then picks an output format and whether to
    download.  Choosing the 'image' format always requests the heatmap
    image; for any other format the user is asked whether to generate one.

    Returns:
        Whatever ``tools3_protein_pair`` returns.
    """
    tool = "tools3"
    params, _ = get_custom_or_example(tool)
    print_info(f"Using parameters: {params}")

    chosen_format = get_format_choice(tool)
    should_download = get_yes_no("Download results?", False)
    # The 'image' format implies an image, so only prompt for other formats.
    want_image = chosen_format == 'image' or get_yes_no("Generate heatmap image?", False)

    return tools3_protein_pair(params, should_download, chosen_format, want_image)

def run_tools4_protein_example():
    """Interactively run the Tools4 protein-mutation search example.

    Parameters are obtained from ``get_custom_or_example("tools4_protein")``
    and echoed back.  The user is then asked for an output format and whether
    to download the results.

    Returns:
        Whatever ``tools4_search_protein_mutations`` returns.
    """
    tool = "tools4_protein"
    params, _ = get_custom_or_example(tool)
    print_info(f"Using parameters: {params}")

    chosen_format = get_format_choice(tool)
    should_download = get_yes_no("Download results?", False)
    return tools4_search_protein_mutations(params, should_download, chosen_format)

def run_tools4_gene_example():
    """Interactively run the Tools4 gene-mutation search example.

    Parameters are obtained from ``get_custom_or_example("tools4_gene")``
    and echoed back.  The user is then asked for an output format and whether
    to download the results.

    Returns:
        Whatever ``tools4_search_gene_mutations`` returns.
    """
    tool = "tools4_gene"
    params, _ = get_custom_or_example(tool)
    print_info(f"Using parameters: {params}")

    chosen_format = get_format_choice(tool)
    should_download = get_yes_no("Download results?", False)
    return tools4_search_gene_mutations(params, should_download, chosen_format)

def run_tools5_example():
    """Interactively run the Tools5 multi-protein expression example.

    Parameters are obtained from ``get_custom_or_example("tools5")`` and
    echoed back.  The user then picks an output format and whether to
    download.  The 'image' format always requests an image; otherwise the
    user is asked whether to generate a heatmap.  The PCA question is only
    asked when an image will be produced.

    Returns:
        Whatever ``tools5_multi_protein_analysis`` returns.
    """
    tool = "tools5"
    params, _ = get_custom_or_example(tool)
    print_info(f"Using parameters: {params}")

    chosen_format = get_format_choice(tool)
    should_download = get_yes_no("Download results?", False)

    # The 'image' format implies an image, so only prompt for other formats.
    want_image = chosen_format == 'image' or get_yes_no("Generate heatmap image?", False)
    # PCA is an add-on to the image, so skip the question when there is none.
    with_pca = get_yes_no("Include PCA analysis?", False) if want_image else False

    return tools5_multi_protein_analysis(params, should_download, chosen_format, want_image, with_pca)

def run_tools6_example():
    """Interactively run the Tools6 motif-pattern example.

    Parameters are obtained from ``get_custom_or_example("tools6")`` and
    echoed back, then the user picks an output format.  For the 'image'
    format the download prompt defaults to yes; otherwise it defaults to no.

    Returns:
        Whatever ``tools6_motif_analysis`` returns.
    """
    tool = "tools6"
    params, _ = get_custom_or_example(tool)
    print_info(f"Using parameters: {params}")

    chosen_format = get_format_choice(tool)
    # Motif images are usually wanted on disk, hence the "yes" default there.
    prompt, default_answer = (
        ("Download motif image?", True) if chosen_format == 'image'
        else ("Download results?", False)
    )
    should_download = get_yes_no(prompt, default_answer)

    return tools6_motif_analysis(params, should_download, chosen_format)

def run_all_examples():
    """Run all API examples"""
    print_header("Running All API Examples")
    
    results = {}
    
    try:
        # Tools1
        print_info("\n1. Testing Tools1 - Differential Expression Analysis...")
        params = EXAMPLE_DATA['tools1'].copy()
        params['proteins'] = ','.join(params['proteins'])
        results['tools1'] = tools1_differential_analysis(params)
        time.sleep(1)
        
        # Tools2
        print_info("\n2. Testing Tools2 - Protein Interaction Network...")
        results['tools2'] = tools2_protein_interactions(EXAMPLE_DATA['tools2'])
        time.sleep(1)
        
        # Tools3
        print_info("\n3. Testing Tools3 - Protein Pair Relationship...")
        results['tools3'] = tools3_protein_pair(EXAMPLE_DATA['tools3'])
        time.sleep(1)
        
        # Tools4 Protein
        print_info("\n4. Testing Tools4a - Search Protein Mutations...")
        results['tools4_protein'] = tools4_search_protein_mutations(EXAMPLE_DATA['tools4_protein'])
        time.sleep(1)
        
        # Tools4 Gene
        print_info("\n5. Testing Tools4b - Search Gene Mutations...")
        results['tools4_gene'] = tools4_search_gene_mutations(EXAMPLE_DATA['tools4_gene'])
        time.sleep(1)
        
        # Tools5
        print_info("\n6. Testing Tools5 - Multi-Protein Expression Analysis...")
        params = EXAMPLE_DATA['tools5'].copy()
        params['proteins'] = ','.join(params['proteins'])
        results['tools5'] = tools5_multi_protein_analysis(params)
        time.sleep(1)
        
        # Tools6
        print_info("\n7. Testing Tools6 - Motif Pattern Analysis...")
        params = EXAMPLE_DATA['tools6'].copy()
        params['proteins'] = ','.join(params['proteins'])
        results['tools6'] = tools6_motif_analysis(params, download=True, format_type='image')
        
        # Summary
        print_header("All Examples Completed")
        success_count = 0
        for tool, result in results.items():
            if result:
                if isinstance(result, bool) and result:
                    success_count += 1
                elif isinstance(result, dict) and result.get('success'):
                    success_count += 1
        
        print_success(f"Successfully completed {success_count} out of 7 tests")
        
    except Exception as e:
        print_error(f"Error running examples: {e}")
        import traceback
        traceback.print_exc()
    
    return results

# ==================== Main Function ====================
def main():
    """Run the interactive tutorial menu.

    Shows the menu and runs the example matching the user's choice.  It then
    loops until one of these happens:

    - the user chooses 9;
    - the user declines to test another function;
    - the user presses Ctrl-C (including at the menu prompt itself);
    - standard input reaches end-of-file.

    Any other exception raised while running an example is printed with its
    traceback, and the menu is shown again.
    """
    import traceback

    print_header("PalmLab API Tutorial")
    print("This tutorial demonstrates how to use the RESTful API.")
    print(f"Base URL: {BASE_URL}")
    print("\nYou can:")
    print("1. Use pre-defined examples")
    print("2. Enter custom parameters")
    print("3. Choose output format (JSON, CSV, or Image)")
    print("4. Download results to files")

    # Menu number -> example runner.  9 (exit) is handled separately below.
    actions = {
        1: run_tools1_example,
        2: run_tools2_example,
        3: run_tools3_example,
        4: run_tools4_protein_example,
        5: run_tools4_gene_example,
        6: run_tools5_example,
        7: run_tools6_example,
        8: run_all_examples,
    }

    while True:
        # The menu prompt is inside the try so that Ctrl-C there is handled
        # the same way as Ctrl-C while an example is running.
        try:
            display_menu()
            choice = get_user_choice()

            if choice == 9:
                print_success("Goodbye!")
                break

            action = actions.get(choice)
            if action is not None:
                action()

            # Don't ask after running all examples
            if choice != 8:
                if not get_yes_no("\nTest another function?", True):
                    print_success("Goodbye!")
                    break

        except KeyboardInterrupt:
            print_warning("\nOperation cancelled by user")
            break
        except EOFError:
            # stdin is closed: every further prompt would fail the same way,
            # so stop instead of re-showing the menu forever.
            print_warning("\nInput closed, exiting")
            break
        except Exception as e:
            print_error(f"Error: {e}")
            traceback.print_exc()

# ==================== Direct Usage Functions ====================
def test_connection(timeout=10):
    """Check that the API server at BASE_URL answers a GET with HTTP 200.

    Args:
        timeout: Seconds to wait for the server before giving up (default 10).

    Returns:
        True if the server responded with status 200, False if it responded
        with any other status or could not be reached.
    """
    print_header("Testing API Connection")
    try:
        response = requests.get(BASE_URL, timeout=timeout)
    except requests.RequestException as e:
        # Covers transport-level failures raised by requests (connection
        # errors, timeouts, invalid URLs, ...); other errors are real bugs.
        print_error(f"Cannot connect to API server: {e}")
        return False

    if response.status_code == 200:
        print_success(f"API server is reachable at {BASE_URL}")
        return True

    print_error(f"Server returned status code: {response.status_code}")
    return False

def quick_start():
    """Print a short guide showing how to call the tutorial functions directly.

    The guide is meant to be followed in a separate Python session.  It
    therefore begins by importing this module, so that the functions it
    mentions (``test_connection``, ``tools1_differential_analysis``,
    ``tools6_motif_analysis``, ``main``) are in scope.
    """
    print_header("Quick Start Guide")
    print("""
Run these snippets in a Python session after importing the tutorial:
   >>> from api_usage_tutorial import *

1. First, test the connection:
   >>> test_connection()
   
2. Explore Tools1 modes:
   # Default mode
   >>> data = tools1_differential_analysis({
           'proteins': 'P01112,P07900',
           'species': 'human',
           'mode': 'default'
       })
   
   # Cancer vs Normal mode (human only)
   >>> data = tools1_differential_analysis({
           'proteins': 'P01112,P07900',
           'species': 'human',
           'mode': 'cancer_vs_normal'
       })
   
   # Custom mode
   >>> data = tools1_differential_analysis({
           'proteins': 'P01112,P07900',
           'species': 'human',
           'mode': 'custom',
           'group_a_datasets': 'Jurkat_T_cells,LNCaP_cells',
           'group_b_datasets': 'T_cells,293T_cells'
       })
   
3. Download results as CSV:
   >>> tools1_differential_analysis({
           'proteins': 'P01112,P07900',
           'mode': 'cancer_vs_normal'
       }, download=True, format_type='csv')
   
4. Get a motif image:
   >>> tools6_motif_analysis({
           'proteins': 'P01112,P07900',
           'species': 'human'
       }, download=True, format_type='image')
   
For more examples, run the interactive tutorial:
   >>> main()
""")

# ==================== Entry Point ====================
if __name__ == "__main__":
    # Dependency check: requests (like pandas and matplotlib) is imported at
    # module level, so a missing package raises ImportError before this block
    # runs.  Reaching this line means the import already succeeded.
    print_success("requests library is installed")

    # Startup choice -> entry function; anything else falls back to main().
    startup_modes = {
        '1': main,
        '2': run_all_examples,
        '3': quick_start,
    }

    try:
        # Test connection first
        if test_connection():
            # Ask user if they want interactive mode or quick test
            print("\nSelect mode:")
            print("1. Interactive tutorial (recommended for beginners)")
            print("2. Quick test all functions")
            print("3. Show quick start guide")

            mode_choice = input("Enter choice (1-3, default=1): ").strip() or '1'

            run_mode = startup_modes.get(mode_choice)
            if run_mode is None:
                print_error("Invalid choice, starting interactive mode")
                run_mode = main
            run_mode()
        else:
            print_warning("Cannot connect to API server. Please check:")
            print(f"1. The server is running at {BASE_URL}")
            print("2. You have network access to the server")
            print("3. The URL is correct in the BASE_URL variable")

            # NOTE(review): the examples launched from main() presumably still
            # call the API, so expect connection errors while it is unreachable.
            run_offline = get_yes_no("\nRun tutorial in offline mode (examples only)?", False)
            if run_offline:
                main()
    except KeyboardInterrupt:
        # Ctrl-C at a startup prompt or during the quick test: exit cleanly
        # instead of dumping a traceback.
        print_warning("\nOperation cancelled by user")