import React, { useEffect, useRef } from "react";
import * as d3 from "d3";
import { Box } from "@mui/material";

/** One point of the scatter plot, placed by its data-space coordinates. */
interface DataPoint {
  /** Text rendered beside the point's marker. */
  label: string;
  /** Horizontal position, in the same units as the other points' `x`. */
  x: number;
  /** Vertical position, in the same units as the other points' `y`. */
  y: number;
}

/** Props accepted by {@link ScatterPlot}. */
interface ScatterPlotProps {
  /** Points to draw; an empty array renders nothing. */
  data: DataPoint[];
  /** Heading drawn centred above the plot area. */
  chartTitle: string;
  /** Caption drawn below the horizontal axis. */
  xLabel: string;
  /** Caption drawn (rotated) beside the vertical axis. */
  yLabel: string;
}

/** Outer width of the rendered SVG in pixels (plot area plus MARGIN). */
const SVG_WIDTH = 600;
/** Outer height of the rendered SVG in pixels (plot area plus MARGIN). */
const SVG_HEIGHT = 400;
/** Space reserved around the plot area for the title, axes and axis labels. */
const MARGIN = { top: 60, right: 40, bottom: 60, left: 60 };
/** Fraction of the data span added to each end of an axis domain. */
const DOMAIN_PADDING_RATIO = 0.1;

/**
 * Computes a linear-scale domain that covers `values` with
 * DOMAIN_PADDING_RATIO of breathing room at each end.
 *
 * If every value is identical (e.g. a single point), the span is zero and
 * d3 would receive a degenerate `[v, v]` domain. In that case the padding
 * falls back to DOMAIN_PADDING_RATIO of |v|, or to 1 when v is 0.
 *
 * @param values - Data-space coordinates along one axis.
 * @returns A `[min, max]` domain with `min < max` whenever `values` holds at
 *          least one finite number.
 */
function paddedDomain(values: number[]): [number, number] {
  // d3.extent yields [undefined, undefined] when there is nothing comparable.
  const [lo = 0, hi = 0] = d3.extent(values);
  const span = hi - lo;
  const magnitudePad = Math.abs(lo) * DOMAIN_PADDING_RATIO;
  const pad =
    span > 0 ? span * DOMAIN_PADDING_RATIO : magnitudePad > 0 ? magnitudePad : 1;
  return [lo - pad, hi + pad];
}

/**
 * Renders a labelled scatter plot of `data` into a fixed-size SVG with d3.
 *
 * The whole chart is redrawn whenever any prop changes. When `data` is empty,
 * the SVG is cleared so that no stale chart remains.
 */
const ScatterPlot: React.FC<ScatterPlotProps> = ({
  chartTitle,
  xLabel,
  yLabel,
  data,
}) => {
  const svgRef = useRef<SVGSVGElement | null>(null);

  useEffect(() => {
    const svgElement = svgRef.current;
    if (!svgElement) return;

    const svg = d3.select(svgElement);

    // Clear BEFORE the empty-data check: otherwise switching from some data
    // to no data would leave the previous chart on screen.
    svg.selectAll("*").remove();

    // No data: render nothing.
    if (data.length === 0) return;

    // Inner plot-area dimensions.
    const width = SVG_WIDTH - MARGIN.left - MARGIN.right;
    const height = SVG_HEIGHT - MARGIN.top - MARGIN.bottom;

    svg.attr("width", SVG_WIDTH).attr("height", SVG_HEIGHT);

    // Group shifted by the margins; everything except the title lives here.
    const chart = svg
      .append("g")
      .attr("transform", `translate(${MARGIN.left},${MARGIN.top})`);

    // Scales (y range is inverted because SVG y grows downward).
    const xScale = d3
      .scaleLinear()
      .domain(paddedDomain(data.map((d) => d.x)))
      .range([0, width]);
    const yScale = d3
      .scaleLinear()
      .domain(paddedDomain(data.map((d) => d.y)))
      .range([height, 0]);

    // X-axis; negative tick size extends ticks across the plot area.
    chart
      .append("g")
      .attr("transform", `translate(0, ${height})`)
      .call(d3.axisBottom(xScale).tickSize(-height))
      .selectAll("text")
      .attr("font-size", "12px");

    // X-axis label
    chart
      .append("text")
      .attr("x", width / 2)
      .attr("y", height + MARGIN.bottom - 10)
      .attr("text-anchor", "middle")
      .style("font-size", "14px")
      .style("font-weight", "bold")
      .text(xLabel);

    // Y-axis; negative tick size extends ticks across the plot area.
    chart
      .append("g")
      .call(d3.axisLeft(yScale).tickSize(-width))
      .selectAll("text")
      .attr("font-size", "12px");

    // Y-axis label (rotated, so x/y are swapped in the rotated frame).
    chart
      .append("text")
      .attr("transform", "rotate(-90)")
      .attr("x", -height / 2)
      .attr("y", -MARGIN.left + 15)
      .attr("text-anchor", "middle")
      .style("font-size", "14px")
      .style("font-weight", "bold")
      .text(yLabel);

    // Dashed horizontal grid lines. Distinct classes keep the horizontal
    // and vertical data joins from binding to each other's elements.
    chart
      .selectAll(".grid-line-y")
      .data(yScale.ticks())
      .enter()
      .append("line")
      .attr("class", "grid-line grid-line-y")
      .attr("x1", 0)
      .attr("x2", width)
      .attr("y1", (d) => yScale(d))
      .attr("y2", (d) => yScale(d))
      .attr("stroke", "#e0e0e0")
      .attr("stroke-dasharray", "4");

    // Dashed vertical grid lines.
    chart
      .selectAll(".grid-line-x")
      .data(xScale.ticks())
      .enter()
      .append("line")
      .attr("class", "grid-line grid-line-x")
      .attr("y1", 0)
      .attr("y2", height)
      .attr("x1", (d) => xScale(d))
      .attr("x2", (d) => xScale(d))
      .attr("stroke", "#e0e0e0")
      .attr("stroke-dasharray", "4");

    // Chart title, centred over the full SVG width.
    svg
      .append("text")
      .attr("x", SVG_WIDTH / 2)
      .attr("y", MARGIN.top / 2)
      .attr("text-anchor", "middle")
      .style("font-size", "18px")
      .style("font-weight", "bold")
      .text(chartTitle);

    // Point markers
    chart
      .selectAll(".dot")
      .data(data)
      .enter()
      .append("circle")
      .attr("class", "dot")
      .attr("cx", (d) => xScale(d.x))
      .attr("cy", (d) => yScale(d.y))
      .attr("r", 6)
      .style("fill", "#007bff");

    // Text labels offset slightly right of / below each marker centre.
    chart
      .selectAll(".label")
      .data(data)
      .enter()
      .append("text")
      .attr("class", "label")
      .attr("x", (d) => xScale(d.x) + 8)
      .attr("y", (d) => yScale(d.y) + 4)
      .text((d) => d.label)
      .style("font-size", "12px")
      .style("fill", "#555");
  }, [chartTitle, xLabel, yLabel, data]);

  return (
    <Box sx={{ padding: "15px", borderRadius: "5px", background: "white" }}>
      <svg ref={svgRef}></svg>
    </Box>
  );
};

export default ScatterPlot;
