#
# Copyright (c) 2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os
import time

from dotenv import load_dotenv
from loguru import logger
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMRunFrame, TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
from strands import Agent, tool
from strands.models import BedrockModel

load_dotenv(override=True)

"""This example demonstrates how to use the Strands agent with Pipecat in a way where the agent explains its reasoning step-by-step.

You can delegate complex, multi-step tasks to the Strands agent, which can cycle through LLM-based reasoning and tool calls to accomplish the task.

Try asking: "What's the weather where the Golden Gate Bridge is?"
"""


# Strands agent tools


@tool
def get_location_name_from_landmark(landmark: str) -> str:
    """
    Get the location name from a landmark.

    Args:
        landmark (str): The name of the landmark, e.g. "Golden Gate Bridge".
    """
    # NOTE(review): docstring kept verbatim; presumably the @tool decorator exposes it
    # to the agent as the tool description -- confirm before rewording.
    #
    # Demo stand-in for a slow lookup: block for a few seconds, then return a canned
    # answer that does not depend on `landmark`.
    simulated_latency_seconds = 3
    time.sleep(simulated_latency_seconds)
    location_name = "San Francisco, CA"
    return location_name


@tool
def get_lat_long_from_location_name(location: str) -> dict:
    """
    Get the latitude and longitude for a location name.

    Args:
        location (str): The city and state, e.g. "San Francisco, CA".
    """
    # NOTE(review): docstring kept verbatim; presumably the @tool decorator exposes it
    # to the agent as the tool description -- confirm before rewording.
    #
    # Demo stand-in for a slow geocoding call: block for a few seconds, then return
    # fixed coordinates regardless of `location`.
    simulated_latency_seconds = 3
    time.sleep(simulated_latency_seconds)
    coordinates = dict(lat=37.7749, long=-122.4194)
    return coordinates


@tool
def get_current_weather_from_lat_long(lat: float, long: float) -> dict:
    """
    Get the current weather for a specific latitude and longitude.

    Args:
        lat (float): The latitude of the location.
        long (float): The longitude of the location.
    """
    # NOTE(review): docstring kept verbatim; presumably the @tool decorator exposes it
    # to the agent as the tool description -- confirm before rewording.
    #
    # Demo stand-in for a slow weather service: block for a few seconds, then return
    # a fixed report regardless of the coordinates passed in.
    simulated_latency_seconds = 3
    time.sleep(simulated_latency_seconds)
    weather_report = dict(conditions="nice", temperature="75")
    return weather_report


# Every entry is a zero-argument factory rather than a ready-made params object, so
# nothing (notably SileroVADAnalyzer) is instantiated at import time. A factory is
# meant to be called only once the desired transport has been selected.
transport_params = {
    transport_name: (
        # Bind the class as a default argument so each factory keeps its own params
        # class instead of sharing the comprehension's last loop value.
        lambda params_cls=params_cls: params_cls(
            audio_in_enabled=True,
            audio_out_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
        )
    )
    for transport_name, params_cls in (
        ("daily", DailyParams),
        ("twilio", FastAPIWebsocketParams),
        ("webrtc", TransportParams),
    )
}


async def run_bot(transport: BaseTransport):
    """Build the voice pipeline for `transport` and run it until the pipeline task ends.

    The pipeline is transport input -> STT -> user context -> LLM -> TTS -> transport
    output -> assistant context. The LLM is given one tool that forwards the user's
    query to a Strands agent; text the agent emits before its final answer is spoken
    through TTS while the agent is still working, and the final answer is returned to
    the LLM as the tool result.

    Args:
        transport: Transport that provides the pipeline's input and output processors.
    """
    logger.info("Starting bot")

    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))

    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",  # British Reading Lady
    )

    # The Strands agent is invoked through run_in_executor (see the tool handler below),
    # so its callback may run on a worker thread. Keep a handle on the event loop so
    # that thread can hand messages back safely.
    loop = asyncio.get_running_loop()

    next_strands_message_is_last = False
    strands_messages_queue = asyncio.Queue()

    def strands_callback_handler(**kwargs):
        """
        Handle events from the Strands agent.

        Assistant text is queued for speech, except for the message that follows an
        "end_turn" stop: that final message is delivered through the tool result instead.
        """
        nonlocal next_strands_message_is_last
        if "event" in kwargs:
            event_obj = kwargs["event"]
            if event_obj and "messageStop" in event_obj:
                message_stop = event_obj["messageStop"]
                if message_stop and "stopReason" in message_stop:
                    stop_reason = message_stop["stopReason"]
                    if stop_reason == "end_turn":
                        next_strands_message_is_last = True
        elif "message" in kwargs:
            message_obj = kwargs["message"]
            if message_obj and "content" in message_obj and "role" in message_obj:
                role = message_obj["role"]
                content = message_obj["content"]
                if role == "assistant" and isinstance(content, list):
                    for content_obj in content:
                        if isinstance(content_obj, dict) and "text" in content_obj:
                            message = content_obj["text"]
                            if not next_strands_message_is_last:
                                # asyncio.Queue is not thread-safe: schedule the put on
                                # the event loop instead of calling it from this thread.
                                loop.call_soon_threadsafe(
                                    strands_messages_queue.put_nowait, message
                                )

    async def process_strands_messages():
        """Speak queued Strands messages through TTS, one at a time, in arrival order."""
        while True:
            message = await strands_messages_queue.get()
            await tts.queue_frame(TTSSpeakFrame(message))
            strands_messages_queue.task_done()

    strands_agent = Agent(
        model=BedrockModel(
            model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", max_tokens=64000
        ),
        tools=[
            get_location_name_from_landmark,
            get_lat_long_from_location_name,
            get_current_weather_from_lat_long,
        ],
        system_prompt="""
        You are a helpful personal assistant who can look up information about places and weather.

        Your key capabilities:
        1. Look up where landmarks are located.
        2. Find latitude and longitude for a location.
        3. Look up the current weather for a specific latitude and longitude.

        Explain each step of your reasoning in clear, simple, and concise language. Your responses will be converted to audio, so avoid special characters and numbered lists.
        """,
        callback_handler=strands_callback_handler,
    )

    llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))

    async def handle_location_or_weather_related_queries(params: FunctionCallParams, query: str):
        """
        Handle location or weather related queries.

        Args:
            query (str): The user's query, e.g. "What's the weather where the Golden Gate Bridge is?".
        """
        # NOTE(review): docstring kept verbatim; presumably register_direct_function
        # derives the tool description from it -- confirm before rewording.
        nonlocal next_strands_message_is_last
        # Clear the flag for every query. Without this it stays True after the first
        # "end_turn", and all intermediate messages of later queries would be dropped.
        next_strands_message_is_last = False
        # Run in a background thread
        # (Otherwise the agent blocks the event loop; one effect of that is that we don't hear
        # the agent's "thinking" messages until the agent finishes)
        result = await loop.run_in_executor(None, strands_agent, query)
        await params.result_callback(result.message)

    llm.register_direct_function(handle_location_or_weather_related_queries)

    @llm.event_handler("on_function_calls_started")
    async def on_function_calls_started(service, function_calls):
        # Give immediate audible feedback while the tool call is in flight.
        await tts.queue_frame(TTSSpeakFrame("Let me check on that."))

    tools = ToolsSchema(standard_tools=[handle_location_or_weather_related_queries])

    messages = [
        {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way. Start by suggesting that the user ask about the weather where the Golden Gate Bridge is.",
        },
    ]

    context = LLMContext(messages, tools)
    context_aggregator = LLMContextAggregatorPair(context)

    pipeline = Pipeline(
        [
            transport.input(),
            stt,
            context_aggregator.user(),
            llm,
            tts,
            transport.output(),
            context_aggregator.assistant(),
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            enable_metrics=True,
            enable_usage_metrics=True,
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        logger.info("Client connected")
        # Kick off the conversation.
        await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    runner = PipelineRunner(handle_sigint=False)

    # Keep a reference to the consumer task: the event loop only holds weak references
    # to tasks, so an unreferenced one may be garbage-collected mid-execution.
    strands_messages_task = asyncio.create_task(process_strands_messages())
    try:
        await runner.run(task)
    finally:
        # The consumer loops forever; stop it once the pipeline is done so it does not
        # outlive this bot session. gather(..., return_exceptions=True) absorbs the
        # consumer's own CancelledError.
        strands_messages_task.cancel()
        await asyncio.gather(strands_messages_task, return_exceptions=True)


async def bot(runner_args: RunnerArguments):
    """Bot entry point, kept compatible with Pipecat Cloud.

    Args:
        runner_args: Runner arguments forwarded to `create_transport`.
    """
    # Create the transport from the lazy factories in `transport_params`, then hand it
    # straight to the pipeline.
    await run_bot(await create_transport(runner_args, transport_params))


if __name__ == "__main__":
    # Imported only when executed as a script, so importing this module does not pull
    # in the runner's command-line entry point.
    from pipecat.runner.run import main as run_pipecat_main

    run_pipecat_main()
