From 08b5d10a267bc1411c8bbccec5ce9a7f30f5c9c3 Mon Sep 17 00:00:00 2001
From: Tom Hodson <thomas.hodson@ecmwf.int>
Date: Mon, 2 Dec 2024 09:49:45 +0000
Subject: [PATCH] Update tree compresser

---
 tree_compresser/Cargo.toml                    |   2 +-
 .../python_src/tree_traverser/__init__.py     |   3 +-
 tree_compresser/rust_src/lib.rs               | 112 +++++++++---------
 tree_compresser/rust_src/tree.rs              |  28 ++++-
 tree_compresser/tests/open_climate_dt.py      |  12 ++
 tree_compresser/tests/test.py                 |  69 ++++++++---
 6 files changed, 149 insertions(+), 77 deletions(-)
 create mode 100644 tree_compresser/tests/open_climate_dt.py

diff --git a/tree_compresser/Cargo.toml b/tree_compresser/Cargo.toml
index 9b1793c..8a5b366 100644
--- a/tree_compresser/Cargo.toml
+++ b/tree_compresser/Cargo.toml
@@ -7,7 +7,7 @@ edition = "2021"
 rsfdb = {git = "https://github.com/ecmwf/rsfdb", branch = "develop"}
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
-pyo3 = "0.23.1"
+pyo3 = "0.23"
 
 
 [lib]
diff --git a/tree_compresser/python_src/tree_traverser/__init__.py b/tree_compresser/python_src/tree_traverser/__init__.py
index d1138ce..f8c6565 100644
--- a/tree_compresser/python_src/tree_traverser/__init__.py
+++ b/tree_compresser/python_src/tree_traverser/__init__.py
@@ -1 +1,2 @@
-from . import rust as backend
\ No newline at end of file
+from . import rust as backend
+from .CompressedTree import CompressedTree
\ No newline at end of file
diff --git a/tree_compresser/rust_src/lib.rs b/tree_compresser/rust_src/lib.rs
index f74e242..9d9ceef 100644
--- a/tree_compresser/rust_src/lib.rs
+++ b/tree_compresser/rust_src/lib.rs
@@ -4,85 +4,79 @@
 
 use rsfdb::listiterator::KeyValueLevel;
 use rsfdb::request::Request;
-use rsfdb::FDB; // Make sure the `fdb` crate is correctly specified in the dependencies
+use rsfdb::FDB;
 
 use serde_json::{json, Value};
 use std::time::Instant;
 
 use pyo3::prelude::*;
-use pyo3::types::{PyDict, PyList, PyString};
+use pyo3::types::{PyDict, PyInt, PyList, PyString};
 
-use crate::tree::TreeNode;
 use std::collections::HashMap;
 
-/// Formats the sum of two numbers as string.
-#[pyfunction]
-#[pyo3(signature = (request, fdb_config = None))]
-fn traverse_fdb(
-    request: HashMap<String, Vec<String>>,
-    fdb_config: Option<&str>,
-) -> PyResult<String> {
-    let start_time = Instant::now();
-    let fdb = FDB::new(fdb_config).unwrap();
+pub mod tree;
+use std::sync::Arc;
+use std::sync::Mutex;
+use tree::TreeNode;
 
-    let list_request =
-        Request::from_json(json!(request)).expect("Failed to create request from python dict");
+#[pyclass(unsendable)]
+pub struct PyFDB {
+    pub fdb: FDB,
+}
 
-    let list = fdb.list(&list_request, true, true).unwrap();
-
-    // for item in list {
-    //     for kvl in item.request {
-    //         println!("{:?}", kvl);
-    //     }
-    // }
-
-    let mut root = TreeNode::new(KeyValueLevel {
-        key: "root".to_string(),
-        value: "root".to_string(),
-        level: 0,
-    });
-
-    for item in list {
-        if let Some(request) = &item.request {
-            root.insert(&request);
-        }
+#[pymethods]
+impl PyFDB {
+    #[new]
+    #[pyo3(signature = (fdb_config=None))]
+    pub fn new(fdb_config: Option<&str>) -> PyResult<Self> {
+        let fdb = FDB::new(fdb_config)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
+        Ok(PyFDB { fdb })
     }
 
-    // Traverse and print the tree
-    root.traverse(0, &|node, level| {
-        let indent = "  ".repeat(level);
-        println!("{}{}={}", indent, node.key.key, node.key.value);
-    });
+    /// Traverse the FDB with the given request.
+    pub fn traverse_fdb(
+        &self,
+        py: Python<'_>,
+        request: HashMap<String, Vec<String>>,
+    ) -> PyResult<PyObject> {
+        let start_time = Instant::now();
 
-    // Convert the tree to JSON
-    // let json_output = root.to_json();
+        let list_request = Request::from_json(json!(request))
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string()))?;
 
-    // // Print the JSON output
-    // // println!("{}", serde_json::to_string_pretty(&json_output).unwrap());
-    // std::fs::write(
-    //     "output.json",
-    //     serde_json::to_string_pretty(&json_output).unwrap(),
-    // )
-    // .expect("Unable to write file");
+        // List every FDB entry matching the request; failures surface as a Python RuntimeError.
+        let list = self
+            .fdb
+            .list(&list_request, true, true)
+            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
 
-    // let duration = start_time.elapsed();
-    // println!("Total runtime: {:?}", duration);
+        let mut root = TreeNode::new(KeyValueLevel {
+            key: "root".to_string(),
+            value: "root".to_string(),
+            level: 0,
+        });
 
-    Ok(("test").to_string())
+        for item in list {
+            py.check_signals()?;
+
+            if let Some(request) = &item.request {
+                root.insert(&request);
+            }
+        }
+
+        let duration = start_time.elapsed();
+        println!("Total runtime: {:?}", duration);
+
+        let py_dict = root.to_py_dict(py)?;
+        Ok(py_dict)
+    }
 }
 
 use pyo3::prelude::*;
 
-/// Formats the sum of two numbers as string.
-#[pyfunction]
-fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
-    Ok((a + b + 2).to_string())
-}
-
-/// A Python module implemented in Rust. The name of this function must match
-/// the `lib.name` setting in the `Cargo.toml`, else Python will not be able to
-/// import the module.
 #[pymodule]
 fn rust(m: &Bound<'_, PyModule>) -> PyResult<()> {
-    m.add_function(wrap_pyfunction!(traverse_fdb, m)?)
+    m.add_class::<PyFDB>()?;
+    Ok(())
 }
diff --git a/tree_compresser/rust_src/tree.rs b/tree_compresser/rust_src/tree.rs
index 9ab2488..c9ccd0d 100644
--- a/tree_compresser/rust_src/tree.rs
+++ b/tree_compresser/rust_src/tree.rs
@@ -1,7 +1,12 @@
+use pyo3::prelude::*;
+use pyo3::types::PyDict;
+use rsfdb::listiterator::KeyValueLevel;
+use serde_json::Value;
+
 #[derive(Debug)]
 pub struct TreeNode {
-    key: KeyValueLevel,
-    children: Vec<TreeNode>,
+    pub key: KeyValueLevel,
+    pub children: Vec<TreeNode>,
 }
 
 impl TreeNode {
@@ -63,4 +68,23 @@ impl TreeNode {
         // Combine the formatted key with children
         serde_json::json!({ formatted_key: children_json })
     }
+
+    /// Convert this subtree into nested Python dicts labelled `"key=value"`.
+    ///
+    /// Returns a one-entry dict mapping this node's label to a dict that maps each
+    /// child's label to that child's own `to_py_dict` result; a leaf maps to `{}`.
+    pub fn to_py_dict(&self, py: Python) -> PyResult<PyObject> {
+        // An empty `children` list simply yields an empty inner dict.
+        let children_dict = PyDict::new(py);
+        for child in &self.children {
+            let child_key = format!("{}={}", child.key.key, child.key.value);
+            children_dict.set_item(child_key, child.to_py_dict(py)?)?;
+        }
+
+        let py_dict = PyDict::new(py);
+        let formatted_key = format!("{}={}", self.key.key, self.key.value);
+        py_dict.set_item(formatted_key, children_dict)?;
+        // `to_object` is deprecated in PyO3 0.23; convert the bound dict directly.
+        Ok(py_dict.into_any().unbind())
+    }
 }
diff --git a/tree_compresser/tests/open_climate_dt.py b/tree_compresser/tests/open_climate_dt.py
new file mode 100644
index 0000000..501e7dd
--- /dev/null
+++ b/tree_compresser/tests/open_climate_dt.py
@@ -0,0 +1,12 @@
+from tree_traverser import backend, CompressedTree
+from pathlib import Path
+
+data_path = Path("data/compressed_tree_climate_dt.json")
+# Report the on-disk size; true division keeps the fractional megabytes for `.1f`.
+print(f"climate dt compressed tree: {data_path.stat().st_size / 1e6:.1f} MB")
+
+print("Opening json file")
+compressed_tree = CompressedTree.load(data_path)
+
+print("Printing compressed tree")
+print(compressed_tree.reconstruct_compressed_ecmwf_style())
diff --git a/tree_compresser/tests/test.py b/tree_compresser/tests/test.py
index 63a693e..3fc0e63 100644
--- a/tree_compresser/tests/test.py
+++ b/tree_compresser/tests/test.py
@@ -1,26 +1,67 @@
-from tree_traverser import backend
+from tree_traverser import backend, CompressedTree
+import datetime
+import psutil
+from tqdm import tqdm
+from pathlib import Path
+import json
+from more_itertools import chunked
+process = psutil.Process()
+
+def massage_request(r):
+    return {k : v if isinstance(v, list) else [v]
+            for k, v in r.items()}
 
 
+if __name__ == "__main__":
 
-config = """
+    config = """
 ---
 type: remote
 host: databridge-prod-catalogue1-ope.ewctest.link
 port: 10000
 engine: remote
 store: remote
-"""
+    """
 
-def massage_request(r):
-    return {k : v if isinstance(v, list) else [v]
-            for k, v in r.items()}
+    request = {
+            "class": "d1",
+            "dataset": "climate-dt",
+            # "date": "19920420",
+        }
+    
+    data_path = Path("data/compressed_tree_climate_dt.json")
+    if not data_path.exists():
+        compressed_tree = CompressedTree({})
+    else:
+        compressed_tree = CompressedTree.load(data_path)
 
-request = {
-        "class": "d1",
-        "dataset": "extremes-dt",
-        "expver": "0001",
-        "stream": "oper",
-        "date": ["20241117", "20241116"],
-    }
+    fdb = backend.PyFDB(fdb_config = config)
 
-backend.traverse_fdb(massage_request(request), fdb_config = config)
+    visited_path = Path("data/visited_dates.json")
+    if not visited_path.exists():
+        visited_dates = set()
+    else:
+        with open(visited_path, "r") as f:
+            visited_dates = set(json.load(f))
+
+    today = datetime.datetime.today()
+    start = datetime.datetime.strptime("19920420", "%Y%m%d")
+    # `visited_dates` holds "%Y%m%d" strings, so format first and compare like with like.
+    date_list = [(start + datetime.timedelta(days=x)).strftime("%Y%m%d") for x in range((today - start).days)]
+    date_list = [d for d in date_list if d not in visited_dates]
+    for dates in chunked(tqdm(date_list), 5):
+        print(dates[0])
+        print(f"Memory usage: {(process.memory_info().rss)/1e6:.1f} MB")
+
+        r = request | dict(date = dates)
+        tree = fdb.traverse_fdb(massage_request(r))
+
+        compressed_tree.insert_tree(tree)
+        compressed_tree.save(data_path)
+
+        for date in dates:
+            visited_dates.add(date)
+
+        with open(visited_path, "w") as f:
+            json.dump(list(visited_dates), f)
+        # print(compressed_tree.reconstruct_compressed_ecmwf_style())