@ToolParserManager.register_module("hunyuan_a13b")
class HunyuanA13BToolParser(ToolParser):
    """Tool-call parser for Hunyuan-A13B style output.

    Tool calls are expected as a JSON array wrapped in
    ``<tool_calls>...</tool_calls>``, each element an object with ``"name"``
    and ``"arguments"`` keys. ``<tool_calls>`` blocks located strictly inside
    a ``<think>...</think>`` region are ignored by the non-streaming path.
    """

    # Stateless decoder shared by all instances. ``raw_decode`` parses a
    # single JSON value starting at an index and reports where it ended, so
    # trailing text (closing tags, following tools) does not break parsing.
    _json_decoder = json.JSONDecoder()

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # Initialize state for streaming mode
        self.prev_tool_calls: list[dict] = []
        self.current_tool_id = -1
        self.current_tool_name_sent = False
        # Argument text already streamed, one entry per tool index.
        self.streamed_args: list[str] = []
        # For backward compatibility with tests
        self.current_tools_sent: list[bool] = []
        # For backward compatibility with serving code: the most recent
        # complete tool-call array parsed while streaming.
        self.prev_tool_call_arr = []
        # Regex patterns for preprocessing
        self.answer_tool_calls_pattern = re.compile(
            r"<tool_calls>([\s\S]*?)</tool_calls>", re.DOTALL)
        # Compiled once here instead of on every preprocess call.
        self.think_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
        self.tool_name_reg = re.compile(r'"name"\s*:\s*"([^"]+)"')
        # Matches from a tool's "name" key up to the first character of its
        # "arguments" value. The value itself is decoded with raw_decode, so
        # arbitrarily nested objects and braces inside strings are handled.
        self.tool_args_key_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*')
        # Legacy argument regexes (support at most one nesting level). The
        # streaming path no longer uses them; kept for backward compatibility.
        self.tool_empty_arg_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*\{\s*\}')
        self.tool_non_empty_arg_reg = re.compile(
            r'"name"\s*:\s*"[^"]+"\s*,\s*"arguments"\s*:\s*(\{(?:[^{}]|(?:\{[^{}]*\}))*\})'
        )
        self.bot_string = "<tool_calls>"
        # Per-request streaming progress:
        #   current_tool_index: index of the tool being streamed (-1 = none)
        #   tool_ids: generated call id per tool index
        #   sent_tools: per tool, whether name / arguments were sent and the
        #               argument text sent so far
        self.streaming_state: dict[str, Any] = {
            "current_tool_index": -1,
            "tool_ids": [],
            "sent_tools": [],
        }

    def preprocess_model_output(
            self, model_output: str) -> tuple[Optional[str], Optional[str]]:
        """Split model output into leading content and tool-call JSON text.

        Scans ``<tool_calls>...</tool_calls>`` blocks in order and picks the
        first one that is not strictly inside a ``<think>...</think>`` region
        and whose (stripped) body is valid JSON.

        Returns:
            ``(text_before_block, tool_calls_json)`` for the selected block,
            or ``(model_output, None)`` when no such block exists.
        """
        # Think regions do not depend on the candidate block; compute once.
        think_regions = [
            m.span() for m in self.think_pattern.finditer(model_output)
        ]
        for match in self.answer_tool_calls_pattern.finditer(model_output):
            start, end = match.span()
            if any(t_start < start and end < t_end
                   for t_start, t_end in think_regions):
                continue
            tool_calls_content = match.group(1).strip()
            try:
                json.loads(tool_calls_content)
            except json.JSONDecodeError:
                continue
            return model_output[:start], tool_calls_content
        return model_output, None

    def extract_tool_calls(
            self, model_output: str,
            request: ChatCompletionRequest) -> ExtractedToolCallInformation:
        """
        Extract tool calls from a complete model output.

        Entries of the tool-call array that are not objects with both
        ``"name"`` and ``"arguments"`` are skipped. String arguments are
        passed through unchanged; any other JSON value is serialized with
        ``ensure_ascii=False`` so non-ASCII text stays readable and matches
        the raw text emitted by the streaming path.

        If no usable tool call is found, the original text is returned as
        content rather than being dropped.
        """
        try:
            content, potential_tool_calls = self.preprocess_model_output(
                model_output)
            if not potential_tool_calls:
                # some text should be filtered out for no function call
                # this text is in a13b's chat template.
                if content:
                    content = content.replace("助手:", "", 1)
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=content)
            tool_calls_data = json.loads(potential_tool_calls)
            if not isinstance(tool_calls_data, list):
                logger.debug("Tool calls data is not an array")
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=content or model_output,
                )
            tool_calls: list[ToolCall] = []
            for call in tool_calls_data:
                if (not isinstance(call, dict) or "name" not in call
                        or "arguments" not in call):
                    continue
                arguments = call["arguments"]
                if not isinstance(arguments, str):
                    # Dicts, lists, null, numbers... must become a JSON
                    # string; passing them raw fails FunctionCall validation.
                    arguments = json.dumps(arguments, ensure_ascii=False)
                tool_calls.append(
                    ToolCall(
                        id=f"call_{random_uuid()}",
                        type="function",
                        function=FunctionCall(name=call["name"],
                                              arguments=arguments),
                    ))
            if not tool_calls:
                # Nothing usable: do not silently drop the model output.
                return ExtractedToolCallInformation(
                    tools_called=False,
                    tool_calls=[],
                    content=content or model_output,
                )
            if not content or not content.strip():
                # clear the whitespace content.
                content = None
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content,
            )
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            return ExtractedToolCallInformation(tools_called=False,
                                                tool_calls=[],
                                                content=model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Extract tool calls for streaming mode.

        Text that does not start (after whitespace and an optional
        ``<tool_calls>`` tag) with ``[`` is forwarded as content. Once a tool
        array is detected, each call emits at most one delta: the pending
        arguments of the current tool (sent once its JSON value is complete)
        take priority, otherwise the name of the next tool is sent.

        Returns:
            The DeltaMessage to forward, or None if nothing is ready yet.
        """
        start_idx = consume_space(0, current_text)
        saw_bot = current_text.startswith(self.bot_string, start_idx)
        if saw_bot:
            start_idx = consume_space(start_idx + len(self.bot_string),
                                      current_text)
        if start_idx >= len(current_text):
            # Only whitespace so far, or just the opening tag: hold the tag
            # back instead of leaking it into the content stream.
            return None if saw_bot else DeltaMessage(content=delta_text)
        if current_text[start_idx] != "[":
            return DeltaMessage(content=delta_text)
        self._try_parse_json_tools(current_text[start_idx:])
        test_delta = self._handle_test_compatibility(current_text)
        if test_delta:
            return test_delta
        name_matches = self._find_tool_name_matches(current_text)
        tool_count = len(name_matches)
        if tool_count == 0:
            return None
        self._ensure_state_arrays(tool_count)
        current_idx = self.streaming_state["current_tool_index"]
        # Flush the current tool's arguments before moving on to the next
        # tool's name; otherwise they would never be sent.
        args_delta = self._handle_tool_args_streaming(current_text,
                                                      current_idx, tool_count,
                                                      name_matches)
        if args_delta:
            return args_delta
        name_delta = self._handle_tool_name_streaming(current_idx, tool_count,
                                                      name_matches)
        if name_delta:
            return name_delta
        return None

    def _try_parse_json_tools(self, current_text: str):
        """Store the tool-call array in ``prev_tool_call_arr`` once complete.

        ``raw_decode`` is used so text after the array (for example the
        closing ``</tool_calls>`` tag) does not prevent parsing.
        """
        # json.loads skips leading JSON whitespace; raw_decode does not.
        start = len(current_text) - len(current_text.lstrip(" \t\n\r"))
        try:
            parsed_tools, _ = self._json_decoder.raw_decode(
                current_text, start)
        except json.JSONDecodeError:
            return
        if isinstance(parsed_tools, list):
            self.prev_tool_call_arr = parsed_tools

    def _find_tool_name_matches(self, current_text: str) -> list:
        """Return ``tool_name_reg`` matches that name top-level tools.

        A ``"name"`` key inside a tool's arguments would otherwise be counted
        as a separate tool. Matches that fall inside an already-complete
        arguments value are skipped, and scanning stops while the current
        tool's arguments are still incomplete (any later ``"name"`` key must
        then belong to those arguments).
        """
        matches = []
        skip_until = 0
        for name_match in self.tool_name_reg.finditer(current_text):
            if name_match.start() < skip_until:
                continue
            matches.append(name_match)
            key_match = self.tool_args_key_reg.match(current_text,
                                                     name_match.start())
            if key_match is None:
                continue
            try:
                _, value_end = self._json_decoder.raw_decode(
                    current_text, key_match.end())
            except ValueError:
                break
            skip_until = value_end
        return matches

    def _extract_complete_args(self, current_text: str,
                               name_pos: int) -> Optional[str]:
        """Return the complete arguments JSON text for the tool at name_pos.

        Args:
            current_text: Full text generated so far.
            name_pos: Offset of the tool's ``"name"`` key.

        Returns:
            The raw arguments object text, ``"{}"`` for an empty object, or
            None if the arguments are not yet a complete JSON object.
        """
        key_match = self.tool_args_key_reg.match(current_text, name_pos)
        if key_match is None:
            return None
        value_start = key_match.end()
        try:
            value, value_end = self._json_decoder.raw_decode(
                current_text, value_start)
        except ValueError:
            return None
        if not isinstance(value, dict):
            return None
        if not value:
            # Normalize "{ }" and similar to the canonical empty object.
            return "{}"
        return current_text[value_start:value_end]

    def _handle_test_compatibility(self, current_text: str):
        """Send tool 0's name when tests pre-seed ``current_tools_sent=[False]``.

        Returns None unless ``current_tools_sent`` is exactly ``[False]`` and
        a tool name is present in ``current_text``.
        """
        if len(self.current_tools_sent) > 0:
            if (len(self.current_tools_sent) == 1
                    and self.current_tools_sent[0] is False):
                name_match = self.tool_name_reg.search(current_text)
                if name_match:
                    function_name = name_match.group(1)
                    tool_id = f"chatcmpl-tool-{random_uuid()}"
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(
                            index=0,
                            type="function",
                            id=tool_id,
                            function=DeltaFunctionCall(
                                name=function_name).model_dump(
                                    exclude_none=True),
                        )
                    ])
                    self.current_tools_sent = [True]
                    self.current_tool_id = 0
                    self.streaming_state["current_tool_index"] = 0
                    if len(self.streaming_state["sent_tools"]) == 0:
                        self.streaming_state["sent_tools"].append({
                            "sent_name":
                            True,
                            "sent_arguments_prefix":
                            False,
                            "sent_arguments":
                            "",
                        })
                    else:
                        self.streaming_state["sent_tools"][0][
                            "sent_name"] = True
                    self.current_tool_name_sent = True
                    return delta
        return None

    def _ensure_state_arrays(self, tool_count: int):
        """Grow per-tool streaming state lists to hold ``tool_count`` tools."""
        while len(self.streaming_state["sent_tools"]) < tool_count:
            self.streaming_state["sent_tools"].append({
                "sent_name": False,
                "sent_arguments_prefix": False,
                "sent_arguments": "",
            })
        while len(self.streaming_state["tool_ids"]) < tool_count:
            self.streaming_state["tool_ids"].append(None)

    def _handle_tool_name_streaming(self, current_idx: int, tool_count: int,
                                    name_matches):
        """Advance to the next tool and emit its name, id and type.

        This is the only place that advances ``current_tool_index``.

        Returns:
            A DeltaMessage announcing tool ``current_idx + 1``, or None if
            that tool does not exist yet or its name was already sent.
        """
        next_idx = current_idx + 1
        if (next_idx >= tool_count
                or self.streaming_state["sent_tools"][next_idx]["sent_name"]):
            return None
        self.streaming_state["current_tool_index"] = next_idx
        self.current_tool_id = next_idx
        tool_name = name_matches[next_idx].group(1)
        tool_id = f"call_{next_idx}_{random_uuid()}"
        self.streaming_state["tool_ids"][next_idx] = tool_id
        self.streaming_state["sent_tools"][next_idx]["sent_name"] = True
        self.current_tool_name_sent = True
        while len(self.streamed_args) <= next_idx:
            self.streamed_args.append("")
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=next_idx,
                type="function",
                id=tool_id,
                function=DeltaFunctionCall(name=tool_name).model_dump(
                    exclude_none=True),
            )
        ])

    def _handle_tool_args_streaming(self,
                                    current_text: str,
                                    current_idx: int,
                                    tool_count: int,
                                    name_matches: Optional[list] = None):
        """Emit the not-yet-sent arguments of the current tool, if complete.

        Arguments are located from the current tool's own ``"name"`` match,
        so tools with empty arguments and nested argument objects no longer
        shift or corrupt other tools' arguments. This method never advances
        ``current_tool_index``.

        Args:
            current_text: Full text generated so far.
            current_idx: Index of the tool whose arguments are streamed.
            tool_count: Number of tools detected so far.
            name_matches: Matches from ``_find_tool_name_matches``;
                recomputed from ``current_text`` when omitted.

        Returns:
            A DeltaMessage with the pending argument text, or None.
        """
        if not 0 <= current_idx < tool_count:
            return None
        if name_matches is None:
            name_matches = self._find_tool_name_matches(current_text)
        if current_idx >= len(name_matches):
            return None
        args_text = self._extract_complete_args(
            current_text, name_matches[current_idx].start())
        if args_text is None:
            return None
        tool_state = self.streaming_state["sent_tools"][current_idx]
        sent_args = tool_state["sent_arguments"]
        if not args_text.startswith(sent_args):
            return None
        args_diff = args_text[len(sent_args):]
        if not args_diff:
            return None
        tool_state["sent_arguments_prefix"] = True
        tool_state["sent_arguments"] = args_text
        while len(self.streamed_args) <= current_idx:
            self.streamed_args.append("")
        self.streamed_args[current_idx] += args_diff
        return DeltaMessage(tool_calls=[
            DeltaToolCall(
                index=current_idx,
                function=DeltaFunctionCall(
                    arguments=args_diff).model_dump(exclude_none=True),
            )
        ])