From 2a2af8487eb8dbeabdac9cc8763df0254448d953 Mon Sep 17 00:00:00 2001 From: zhaozhn_sugon Date: Fri, 29 May 2026 16:30:20 +0800 Subject: [PATCH 1/5] feat: simplify the data path in earth model. --- examples/earth/fengwu/conf/config.yaml | 2 - examples/earth/fengwu/fake_data.py | 35 +++++++------ examples/earth/fourcastnet/conf/config.yaml | 2 - examples/earth/fourcastnet/fake_data.py | 24 +++++---- examples/earth/fuxi/conf/config.yaml | 2 - examples/earth/fuxi/fake_data.py | 8 +-- .../earth/graphcast/compute_time_diff_std.py | 2 - examples/earth/graphcast/conf/config.yaml | 2 - examples/earth/graphcast/fake_data.py | 50 +++++++------------ examples/earth/pangu_weather/conf/config.yaml | 2 - examples/earth/pangu_weather/fake_data.py | 24 +++++---- examples/earth/pangu_weather/train.py | 8 +-- 12 files changed, 74 insertions(+), 87 deletions(-) diff --git a/examples/earth/fengwu/conf/config.yaml b/examples/earth/fengwu/conf/config.yaml index 259a79f1..c1260f23 100755 --- a/examples/earth/fengwu/conf/config.yaml +++ b/examples/earth/fengwu/conf/config.yaml @@ -20,8 +20,6 @@ datapipe: # dataset设定 dataset: type: "hdf5" - stats_dir: './data/stats/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/stats/" - static_dir: './data/static/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/static/" data_dir: './data/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/" train_time: [1951, 1952] diff --git a/examples/earth/fengwu/fake_data.py b/examples/earth/fengwu/fake_data.py index 00449ec9..fe565878 100644 --- a/examples/earth/fengwu/fake_data.py +++ b/examples/earth/fengwu/fake_data.py @@ -23,8 +23,8 @@ def generate_fake_h5(data_dir, var_names, years, dims): path = os.path.join(data_dir, "data", f"{year}.h5") with h5py.File(path, "w") as f: ds = f.create_dataset( - "fields", - shape=(T, C, H, W), + "fields", + shape=(T, C, H, W), dtype="float32", chunks=(1, C, H, W), fillvalue=0.0, @@ -49,8 +49,8 @@ def generate_stats(data_dir, n_vars): -def get_static(cfg, var, name): - os.makedirs(cfg.static_dir, exist_ok=True) +def get_static(data_dir, var, name): + os.makedirs(data_dir, exist_ok=True) ds = xr.Dataset( data_vars={ f"{var}": (("valid_time", "latitude", "longitude"), @@ -73,17 +73,18 @@ def get_static(cfg, var, name): } ) - ds.to_netcdf(f"{cfg.static_dir}/{name}.nc") - arr = np.random.randn(721, 1440).astype(np.float32) # 保存数据 - np.save(f'{cfg.static_dir}/land_mask.npy', arr) - np.save(f'{cfg.static_dir}/soil_type.npy', arr) - np.save(f'{cfg.static_dir}/topography.npy', arr) - print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {cfg.static_dir}") - + ds.to_netcdf(f"{data_dir}/{name}.nc") + arr = np.random.randn(721, 1440).astype(np.float32) + np.save(f'{data_dir}/land_mask.npy', arr) + np.save(f'{data_dir}/soil_type.npy', arr) + np.save(f'{data_dir}/topography.npy', arr) + print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {data_dir}") + + if __name__ == "__main__": cfg_datapipe = YParams("conf/config.yaml", "datapipe") - if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work/"): + if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work2/"): print("请检查 config,确保各 *_dir 指向本地测试路径而非生产路径。") exit() @@ -91,10 +92,12 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - generate_stats(cfg_datapipe.dataset.stats_dir, len(atm_vars)) + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") + generate_stats(stats_dir, len(atm_vars)) + + static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + get_static(static_dir, 'z', 'geopotential') + get_static(static_dir, 'lsm', 'land_sea_mask') - get_static(cfg_datapipe.dataset, 'z', 'geopotential') - get_static(cfg_datapipe.dataset, 'lsm', 'land_sea_mask') - print("\n✅ Fake datasets generated.") diff --git a/examples/earth/fourcastnet/conf/config.yaml b/examples/earth/fourcastnet/conf/config.yaml index ca8f9c49..d754a09a 100755 --- a/examples/earth/fourcastnet/conf/config.yaml +++ b/examples/earth/fourcastnet/conf/config.yaml @@ -16,8 +16,6 @@ datapipe: # dataset设定 dataset: type: "hdf5" - stats_dir: './data/stats/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/stats/" - static_dir: './data/static/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/static/" data_dir: './data/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/" train_time: [1951, 1952] diff --git a/examples/earth/fourcastnet/fake_data.py b/examples/earth/fourcastnet/fake_data.py index c9eb954b..fe565878 100644 --- a/examples/earth/fourcastnet/fake_data.py +++ b/examples/earth/fourcastnet/fake_data.py @@ -49,8 +49,8 @@ def generate_stats(data_dir, n_vars): -def get_static(cfg, var, name): - os.makedirs(cfg.static_dir, exist_ok=True) +def get_static(data_dir, var, name): + os.makedirs(data_dir, exist_ok=True) ds = xr.Dataset( data_vars={ f"{var}": (("valid_time", "latitude", "longitude"), @@ -73,18 +73,18 @@ def get_static(cfg, var, name): } ) - ds.to_netcdf(f"{cfg.static_dir}/{name}.nc") + ds.to_netcdf(f"{data_dir}/{name}.nc") arr = np.random.randn(721, 1440).astype(np.float32) - np.save(f'{cfg.static_dir}/land_mask.npy', arr) - np.save(f'{cfg.static_dir}/soil_type.npy', arr) - np.save(f'{cfg.static_dir}/topography.npy', arr) - print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {cfg.static_dir}") + np.save(f'{data_dir}/land_mask.npy', arr) + np.save(f'{data_dir}/soil_type.npy', arr) + np.save(f'{data_dir}/topography.npy', arr) + print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {data_dir}") if __name__ == "__main__": cfg_datapipe = YParams("conf/config.yaml", "datapipe") - if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work/"): + if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work2/"): print("请检查 config,确保各 *_dir 指向本地测试路径而非生产路径。") exit() @@ -92,10 +92,12 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - generate_stats(cfg_datapipe.dataset.stats_dir, len(atm_vars)) + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") + generate_stats(stats_dir, len(atm_vars)) - get_static(cfg_datapipe.dataset, 'z', 'geopotential') - get_static(cfg_datapipe.dataset, 'lsm', 'land_sea_mask') + static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + get_static(static_dir, 'z', 'geopotential') + get_static(static_dir, 'lsm', 'land_sea_mask') print("\n✅ Fake datasets generated.") diff --git a/examples/earth/fuxi/conf/config.yaml b/examples/earth/fuxi/conf/config.yaml index 612f1bf0..38aa0f93 100755 --- a/examples/earth/fuxi/conf/config.yaml +++ b/examples/earth/fuxi/conf/config.yaml @@ -27,8 +27,6 @@ datapipe: # dataset设定 dataset: type: "hdf5" - stats_dir: './data/stats/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/stats/" - static_dir: './data/static/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/static/" data_dir: './data/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/" train_time: [1951, 1952] val_time: [1953] diff --git a/examples/earth/fuxi/fake_data.py b/examples/earth/fuxi/fake_data.py index eb0770aa..bc6ac007 100644 --- a/examples/earth/fuxi/fake_data.py +++ b/examples/earth/fuxi/fake_data.py @@ -137,13 +137,15 @@ if __name__ == "__main__": # 主 ERA5 数据 generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - generate_stats(cfg_datapipe.dataset.stats_dir, n_vars) + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") + generate_stats(stats_dir, n_vars) # 各阶段中间模型输出(用于 short/medium/long 微调的输入) for stage in ['base', 'short', 'medium', 'long']: generate_fake_npy(f'./result/{stage}', n_vars, years, DATASET_DIMS) - get_static(cfg_datapipe.dataset, 'z', 'geopotential') - get_static(cfg_datapipe.dataset, 'lsm', 'land_sea_mask') + static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + get_static(static_dir, 'z', 'geopotential') + get_static(static_dir, 'lsm', 'land_sea_mask') print("\n✅ Fake datasets generated.") diff --git a/examples/earth/graphcast/compute_time_diff_std.py b/examples/earth/graphcast/compute_time_diff_std.py index b7f52ef1..5328cc8a 100755 --- a/examples/earth/graphcast/compute_time_diff_std.py +++ b/examples/earth/graphcast/compute_time_diff_std.py @@ -26,7 +26,6 @@ def main(): area = area.unsqueeze(1) mean, mean_sqr = 0, 0 - k = 0 for data in tqdm(train_dataloader): invar = data[0] # [b, N, h, w] outvar = data[1] # [b, N, h, w] @@ -35,7 +34,6 @@ def main(): weighted_diff_sqr = torch.square(weighted_diff) mean += torch.mean(weighted_diff, dim=(2, 3)) / len(train_dataloader) mean_sqr += torch.mean(weighted_diff_sqr, dim=(2, 3)) / len(train_dataloader) - k += 1 variance = mean_sqr - mean**2 # [1,num_channel, 1,1] std = torch.sqrt(variance) diff --git a/examples/earth/graphcast/conf/config.yaml b/examples/earth/graphcast/conf/config.yaml index 4d9fd444..1efe2e35 100755 --- a/examples/earth/graphcast/conf/config.yaml +++ b/examples/earth/graphcast/conf/config.yaml @@ -52,8 +52,6 @@ datapipe: # dataset设定 dataset: type: "hdf5" - stats_dir: './data/stats/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/stats/" - static_dir: './data/static/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/static/" data_dir: './data/' # "$ONESCIENCE_DATASETS_DIR/ERA5/newh5/" dataset_metadata_path: './data.json' # Path to the dataset metadata, containing channel names. time_diff_std_path: './time_diff_std.npy' # Path to the .npy file with standard deviation of normalized per-variable per-pressure level time differences. diff --git a/examples/earth/graphcast/fake_data.py b/examples/earth/graphcast/fake_data.py index 1f0737f7..fe565878 100644 --- a/examples/earth/graphcast/fake_data.py +++ b/examples/earth/graphcast/fake_data.py @@ -6,46 +6,32 @@ from onescience.utils.YParams import YParams # 各数据集固定的空间和时间维度 -DATASET_DIMS = {"T": 50, "H": 721, "W": 1440, "time_step": 6} -PATCH_HW = 32 -PATCHES_PER_STEP = 2 -RNG_SEED = 42 +DATASET_DIMS = {"T": 10, "H": 721, "W": 1440, "time_step": 6} def generate_fake_h5(data_dir, var_names, years, dims): """ - 为每个年份生成一个稀疏随机 h5 文件。 - 仅写入少量随机 patch,其余 chunk 仍保持 fill_value=0, - 这样既能让 time-diff std 非零,也能保持文件很小。 + 为每个年份生成一个空 h5 文件。 + 利用 HDF5 chunked 压缩数据集未写入 chunk 即返回 fill_value=0 的特性, + 文件实际只含元数据,极小,但 shape 与真实数据完全一致。 """ os.makedirs(os.path.join(data_dir, "data"), exist_ok=True) T, C = dims["T"], len(var_names) H, W = dims["H"], dims["W"] - patch_h = min(PATCH_HW, H) - patch_w = min(PATCH_HW, W) - for year_idx, year in enumerate(years): + for year in years: path = os.path.join(data_dir, "data", f"{year}.h5") - rng = np.random.default_rng(RNG_SEED + year_idx) with h5py.File(path, "w") as f: ds = f.create_dataset( "fields", shape=(T, C, H, W), dtype="float32", - chunks=(1, C, patch_h, patch_w), + chunks=(1, C, H, W), fillvalue=0.0, ) ds.attrs["variables"] = var_names ds.attrs["time_step"] = dims["time_step"] - for t in range(T): - for _ in range(PATCHES_PER_STEP): - top = int(rng.integers(0, H - patch_h + 1)) - left = int(rng.integers(0, W - patch_w + 1)) - patch = rng.standard_normal((C, patch_h, patch_w)).astype(np.float32) - patch += np.float32(0.1 * t) - ds[t, :, top:top + patch_h, left:left + patch_w] = patch - size_kb = os.path.getsize(path) / 1024 print(f" {year}.h5 shape=({T},{C},{H},{W}) " f"logical={T*C*H*W*4/1024**3:.1f}GB actual={size_kb:.1f}KB") @@ -63,8 +49,8 @@ def generate_stats(data_dir, n_vars): -def get_static(cfg, var, name): - os.makedirs(cfg.static_dir, exist_ok=True) +def get_static(data_dir, var, name): + os.makedirs(data_dir, exist_ok=True) ds = xr.Dataset( data_vars={ f"{var}": (("valid_time", "latitude", "longitude"), @@ -87,18 +73,18 @@ def get_static(cfg, var, name): } ) - ds.to_netcdf(f"{cfg.static_dir}/{name}.nc") + ds.to_netcdf(f"{data_dir}/{name}.nc") arr = np.random.randn(721, 1440).astype(np.float32) - np.save(f'{cfg.static_dir}/land_mask.npy', arr) - np.save(f'{cfg.static_dir}/soil_type.npy', arr) - np.save(f'{cfg.static_dir}/topography.npy', arr) - print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {cfg.static_dir}") + np.save(f'{data_dir}/land_mask.npy', arr) + np.save(f'{data_dir}/soil_type.npy', arr) + np.save(f'{data_dir}/topography.npy', arr) + print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {data_dir}") if __name__ == "__main__": cfg_datapipe = YParams("conf/config.yaml", "datapipe") - if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work/"): + if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work2/"): print("请检查 config,确保各 *_dir 指向本地测试路径而非生产路径。") exit() @@ -106,10 +92,12 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - generate_stats(cfg_datapipe.dataset.stats_dir, len(atm_vars)) + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") + generate_stats(stats_dir, len(atm_vars)) - get_static(cfg_datapipe.dataset, 'z', 'geopotential') - get_static(cfg_datapipe.dataset, 'lsm', 'land_sea_mask') + static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + get_static(static_dir, 'z', 'geopotential') + get_static(static_dir, 'lsm', 'land_sea_mask') print("\n✅ Fake datasets generated.") diff --git a/examples/earth/pangu_weather/conf/config.yaml b/examples/earth/pangu_weather/conf/config.yaml index de1c893a..8339c796 100755 --- a/examples/earth/pangu_weather/conf/config.yaml +++ b/examples/earth/pangu_weather/conf/config.yaml @@ -16,8 +16,6 @@ datapipe: # dataset设定 dataset: type: "hdf5" - stats_dir: './data/stats/' - static_dir: './data/static/' # data_dir: './data/' train_time: [1951, 1952] diff --git a/examples/earth/pangu_weather/fake_data.py b/examples/earth/pangu_weather/fake_data.py index c9eb954b..fe565878 100644 --- a/examples/earth/pangu_weather/fake_data.py +++ b/examples/earth/pangu_weather/fake_data.py @@ -49,8 +49,8 @@ def generate_stats(data_dir, n_vars): -def get_static(cfg, var, name): - os.makedirs(cfg.static_dir, exist_ok=True) +def get_static(data_dir, var, name): + os.makedirs(data_dir, exist_ok=True) ds = xr.Dataset( data_vars={ f"{var}": (("valid_time", "latitude", "longitude"), @@ -73,18 +73,18 @@ def get_static(cfg, var, name): } ) - ds.to_netcdf(f"{cfg.static_dir}/{name}.nc") + ds.to_netcdf(f"{data_dir}/{name}.nc") arr = np.random.randn(721, 1440).astype(np.float32) - np.save(f'{cfg.static_dir}/land_mask.npy', arr) - np.save(f'{cfg.static_dir}/soil_type.npy', arr) - np.save(f'{cfg.static_dir}/topography.npy', arr) - print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {cfg.static_dir}") + np.save(f'{data_dir}/land_mask.npy', arr) + np.save(f'{data_dir}/soil_type.npy', arr) + np.save(f'{data_dir}/topography.npy', arr) + print(f"✅ Static data: {arr.shape}, dtype: {arr.dtype}, save to {data_dir}") if __name__ == "__main__": cfg_datapipe = YParams("conf/config.yaml", "datapipe") - if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work/"): + if cfg_datapipe.dataset.data_dir.startswith("/public/") or cfg_datapipe.dataset.data_dir.startswith("/work2/"): print("请检查 config,确保各 *_dir 指向本地测试路径而非生产路径。") exit() @@ -92,10 +92,12 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - generate_stats(cfg_datapipe.dataset.stats_dir, len(atm_vars)) + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") + generate_stats(stats_dir, len(atm_vars)) - get_static(cfg_datapipe.dataset, 'z', 'geopotential') - get_static(cfg_datapipe.dataset, 'lsm', 'land_sea_mask') + static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + get_static(static_dir, 'z', 'geopotential') + get_static(static_dir, 'lsm', 'land_sea_mask') print("\n✅ Fake datasets generated.") diff --git a/examples/earth/pangu_weather/train.py b/examples/earth/pangu_weather/train.py index 6badf83d..5b177be8 100755 --- a/examples/earth/pangu_weather/train.py +++ b/examples/earth/pangu_weather/train.py @@ -57,9 +57,11 @@ def main(): surface_weights = torch.as_tensor(cfg_data.dataset.weights[:4], device=local_rank, dtype=torch.float32).view(1, -1, 1, 1) pressure_weights = torch.as_tensor(cfg_data.dataset.weights[4:], device=local_rank, dtype=torch.float32).view(1, -1, 1, 1) - land_mask = torch.from_numpy(np.load(os.path.join(cfg_data.dataset.static_dir, "land_mask.npy")).astype(np.float32)) - soil_type = torch.from_numpy(np.load(os.path.join(cfg_data.dataset.static_dir, "soil_type.npy")).astype(np.float32)) - topography = torch.from_numpy(np.load(os.path.join(cfg_data.dataset.static_dir, "topography.npy")).astype(np.float32)) + static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + + land_mask = torch.from_numpy(np.load(os.path.join(static_dir, "land_mask.npy")).astype(np.float32)) + soil_type = torch.from_numpy(np.load(os.path.join(static_dir, "soil_type.npy")).astype(np.float32)) + topography = torch.from_numpy(np.load(os.path.join(static_dir, "topography.npy")).astype(np.float32)) topography = (topography - topography.mean()) / (topography.std(unbiased=False) + 1e-6) surface_mask = torch.stack([land_mask, soil_type, topography], dim=0).to(local_rank) surface_mask = surface_mask.unsqueeze(0).repeat(cfg_data.dataloader.batch_size, 1, 1, 1) -- Gitee From 46c0f7e25f670087009c6f40dc0f18aa685a37c8 Mon Sep 17 00:00:00 2001 From: zhaozhn_sugon Date: Fri, 29 May 2026 17:10:47 +0800 Subject: [PATCH 2/5] fix: add comma --- examples/earth/fengwu/fake_data.py | 2 +- examples/earth/fourcastnet/fake_data.py | 2 +- examples/earth/fuxi/fake_data.py | 4 ++-- examples/earth/pangu_weather/fake_data.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/earth/fengwu/fake_data.py b/examples/earth/fengwu/fake_data.py index fe565878..fa7762c5 100644 --- a/examples/earth/fengwu/fake_data.py +++ b/examples/earth/fengwu/fake_data.py @@ -95,7 +95,7 @@ if __name__ == "__main__": stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") generate_stats(stats_dir, len(atm_vars)) - static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") get_static(static_dir, 'z', 'geopotential') get_static(static_dir, 'lsm', 'land_sea_mask') diff --git a/examples/earth/fourcastnet/fake_data.py b/examples/earth/fourcastnet/fake_data.py index fe565878..fa7762c5 100644 --- a/examples/earth/fourcastnet/fake_data.py +++ b/examples/earth/fourcastnet/fake_data.py @@ -95,7 +95,7 @@ if __name__ == "__main__": stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") generate_stats(stats_dir, len(atm_vars)) - static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") get_static(static_dir, 'z', 'geopotential') get_static(static_dir, 'lsm', 'land_sea_mask') diff --git a/examples/earth/fuxi/fake_data.py b/examples/earth/fuxi/fake_data.py index bc6ac007..b4bf55e6 100644 --- a/examples/earth/fuxi/fake_data.py +++ b/examples/earth/fuxi/fake_data.py @@ -137,14 +137,14 @@ if __name__ == "__main__": # 主 ERA5 数据 generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir, "stats") generate_stats(stats_dir, n_vars) # 各阶段中间模型输出(用于 short/medium/long 微调的输入) for stage in ['base', 'short', 'medium', 'long']: generate_fake_npy(f'./result/{stage}', n_vars, years, DATASET_DIMS) - static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") get_static(static_dir, 'z', 'geopotential') get_static(static_dir, 'lsm', 'land_sea_mask') diff --git a/examples/earth/pangu_weather/fake_data.py b/examples/earth/pangu_weather/fake_data.py index fe565878..fa7762c5 100644 --- a/examples/earth/pangu_weather/fake_data.py +++ b/examples/earth/pangu_weather/fake_data.py @@ -95,7 +95,7 @@ if __name__ == "__main__": stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") generate_stats(stats_dir, len(atm_vars)) - static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") get_static(static_dir, 'z', 'geopotential') get_static(static_dir, 'lsm', 'land_sea_mask') -- Gitee From d7e6463fd9e6defb876a77864c48567df9a05305 Mon Sep 17 00:00:00 2001 From: zhaozhn_sugon Date: Fri, 29 May 2026 17:11:58 +0800 Subject: [PATCH 3/5] fix: add comma --- examples/earth/fengwu/fake_data.py | 2 +- examples/earth/fourcastnet/fake_data.py | 2 +- examples/earth/graphcast/fake_data.py | 4 ++-- examples/earth/pangu_weather/fake_data.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/earth/fengwu/fake_data.py b/examples/earth/fengwu/fake_data.py index fa7762c5..2923e980 100644 --- a/examples/earth/fengwu/fake_data.py +++ b/examples/earth/fengwu/fake_data.py @@ -92,7 +92,7 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir, "stats") generate_stats(stats_dir, len(atm_vars)) static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") diff --git a/examples/earth/fourcastnet/fake_data.py b/examples/earth/fourcastnet/fake_data.py index fa7762c5..2923e980 100644 --- a/examples/earth/fourcastnet/fake_data.py +++ b/examples/earth/fourcastnet/fake_data.py @@ -92,7 +92,7 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir, "stats") generate_stats(stats_dir, len(atm_vars)) static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") diff --git a/examples/earth/graphcast/fake_data.py b/examples/earth/graphcast/fake_data.py index fe565878..2923e980 100644 --- a/examples/earth/graphcast/fake_data.py +++ b/examples/earth/graphcast/fake_data.py @@ -92,10 +92,10 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir, "stats") generate_stats(stats_dir, len(atm_vars)) - static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") get_static(static_dir, 'z', 'geopotential') get_static(static_dir, 'lsm', 'land_sea_mask') diff --git a/examples/earth/pangu_weather/fake_data.py b/examples/earth/pangu_weather/fake_data.py index fa7762c5..2923e980 100644 --- a/examples/earth/pangu_weather/fake_data.py +++ b/examples/earth/pangu_weather/fake_data.py @@ -92,7 +92,7 @@ if __name__ == "__main__": atm_vars = cfg_datapipe.dataset.channels generate_fake_h5(cfg_datapipe.dataset.data_dir, atm_vars, years, DATASET_DIMS) - stats_dir = os.path.join(cfg_datapipe.dataset.data_dir "stats") + stats_dir = os.path.join(cfg_datapipe.dataset.data_dir, "stats") generate_stats(stats_dir, len(atm_vars)) static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") -- Gitee From 119d97ba2017a7a78734c73f54b13c80fff3e882 Mon Sep 17 00:00:00 2001 From: zhaozhn_sugon Date: Fri, 29 May 2026 17:14:23 +0800 Subject: [PATCH 4/5] fix: add comma --- examples/earth/pangu_weather/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/earth/pangu_weather/train.py b/examples/earth/pangu_weather/train.py index 5b177be8..936b3e86 100755 --- a/examples/earth/pangu_weather/train.py +++ b/examples/earth/pangu_weather/train.py @@ -57,7 +57,7 @@ def main(): surface_weights = torch.as_tensor(cfg_data.dataset.weights[:4], device=local_rank, dtype=torch.float32).view(1, -1, 1, 1) pressure_weights = torch.as_tensor(cfg_data.dataset.weights[4:], device=local_rank, dtype=torch.float32).view(1, -1, 1, 1) - static_dir = os.path.join(cfg_datapipe.dataset.data_dir "static") + static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") land_mask = torch.from_numpy(np.load(os.path.join(static_dir, "land_mask.npy")).astype(np.float32)) soil_type = torch.from_numpy(np.load(os.path.join(static_dir, "soil_type.npy")).astype(np.float32)) -- Gitee From cbd5bf026408447c2e733fab3feeadde01346a66 Mon Sep 17 00:00:00 2001 From: zhaozhn_sugon Date: Fri, 29 May 2026 17:15:01 +0800 Subject: [PATCH 5/5] fix: add comma --- examples/earth/pangu_weather/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/earth/pangu_weather/train.py b/examples/earth/pangu_weather/train.py index 936b3e86..db443863 100755 --- a/examples/earth/pangu_weather/train.py +++ b/examples/earth/pangu_weather/train.py @@ -57,7 +57,7 @@ def main(): surface_weights = torch.as_tensor(cfg_data.dataset.weights[:4], device=local_rank, dtype=torch.float32).view(1, -1, 1, 1) pressure_weights = torch.as_tensor(cfg_data.dataset.weights[4:], device=local_rank, dtype=torch.float32).view(1, -1, 1, 1) - static_dir = os.path.join(cfg_datapipe.dataset.data_dir, "static") + static_dir = os.path.join(cfg_data.dataset.data_dir, "static") land_mask = torch.from_numpy(np.load(os.path.join(static_dir, "land_mask.npy")).astype(np.float32)) soil_type = torch.from_numpy(np.load(os.path.join(static_dir, "soil_type.npy")).astype(np.float32)) -- Gitee