import_jd_provider_tmp.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from pathlib import Path
  2. import pandas as pd
  3. import pymysql
  4. def get_conn():
  5. """
  6. 数据库连接配置(你自己补)。
  7. """
  8. return pymysql.connect(
  9. host="127.0.0.1",
  10. port=3306,
  11. user="root",
  12. password="your_password",
  13. database="your_database",
  14. charset="utf8mb4",
  15. autocommit=False,
  16. )
  17. def read_source_file(file_path: Path, sheet_name=0) -> pd.DataFrame:
  18. suffix = file_path.suffix.lower()
  19. if suffix in {".xlsx", ".xls"}:
  20. return pd.read_excel(file_path, sheet_name=sheet_name, dtype=object)
  21. if suffix == ".csv":
  22. return pd.read_csv(file_path, dtype=object, encoding="utf-8-sig")
  23. raise ValueError(f"不支持的文件类型: {suffix}")
  24. def get_table_columns(conn, table_name: str):
  25. sql = """
  26. SELECT COLUMN_NAME
  27. FROM information_schema.COLUMNS
  28. WHERE TABLE_SCHEMA = DATABASE()
  29. AND TABLE_NAME = %s
  30. ORDER BY ORDINAL_POSITION
  31. """
  32. with conn.cursor() as cur:
  33. cur.execute(sql, (table_name,))
  34. rows = cur.fetchall()
  35. return [r[0] for r in rows]
  36. def clean_df(df: pd.DataFrame) -> pd.DataFrame:
  37. # 去掉空列名和列名前后空格
  38. df = df.rename(columns=lambda x: str(x).strip() if x is not None else "")
  39. df = df[[c for c in df.columns if c]]
  40. # 把 NaN 转为 None,便于写库
  41. return df.where(pd.notna(df), None)
  42. def import_to_jd_provider_tmp(file_path: str, table_name="jd_provider_tmp", sheet_name=0, batch_size=500):
  43. file = Path(file_path)
  44. if not file.exists():
  45. raise FileNotFoundError(f"文件不存在: {file}")
  46. df = clean_df(read_source_file(file, sheet_name=sheet_name))
  47. if df.empty:
  48. print("源文件无数据,跳过导入。")
  49. return
  50. conn = get_conn()
  51. try:
  52. table_columns = get_table_columns(conn, table_name)
  53. if not table_columns:
  54. raise RuntimeError(f"未找到表或无字段: {table_name}")
  55. # 仅导入“表头和数据库字段同名”的列
  56. import_columns = [c for c in df.columns if c in table_columns]
  57. if not import_columns:
  58. raise RuntimeError("文件表头与数据库字段无交集,请检查表头是否与表字段同名。")
  59. missing_columns = [c for c in df.columns if c not in table_columns]
  60. if missing_columns:
  61. print(f"以下表头不在数据库表中,已自动忽略: {missing_columns}")
  62. insert_sql = (
  63. f"INSERT INTO `{table_name}` ({', '.join([f'`{c}`' for c in import_columns])}) "
  64. f"VALUES ({', '.join(['%s'] * len(import_columns))})"
  65. )
  66. values = [tuple(row[c] for c in import_columns) for _, row in df.iterrows()]
  67. total = len(values)
  68. inserted = 0
  69. with conn.cursor() as cur:
  70. for i in range(0, total, batch_size):
  71. batch = values[i:i + batch_size]
  72. cur.executemany(insert_sql, batch)
  73. inserted += len(batch)
  74. print(f"已导入: {inserted}/{total}")
  75. conn.commit()
  76. print(f"导入完成,表 `{table_name}` 共导入 {inserted} 条。")
  77. except Exception:
  78. conn.rollback()
  79. raise
  80. finally:
  81. conn.close()
  82. if __name__ == "__main__":
  83. fixed_file = Path.cwd() / "1111.xlsx"
  84. import_to_jd_provider_tmp(
  85. file_path=str(fixed_file),
  86. table_name="jd_provider_tmp",
  87. sheet_name=0,
  88. batch_size=500,
  89. )