basic_model
v1.0.0 · Python · Runtime: python-base · License: Apache-2.0

Basic model demonstrating the Spark, Elasticsearch, and local CSV data adapters. Supports both the v1 (execute) and v2 (train/infer) interfaces for backward compatibility.

$ openuba install basic_model OpenUBA

Tags: data-adapter, spark, elasticsearch, csv, multi-source, v1-compatible
Parameters
| Name | Type | Default | Description |
|---|---|---|---|
| threshold | integer | 10000 | Row count threshold for anomaly detection |
| data_source | string | spark | Data source type (spark, elasticsearch, local_csv) |
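
At runtime these parameters arrive with the rest of the model input: the v1 execute interface reads them from its input_data dict (together with source-specific keys such as table_name or index_name), and the v2 interface reads them from ctx.params, as MODEL.py below shows. A minimal sketch of a v1 payload for the Spark source, with an illustrative table name and threshold:

```python
# illustrative v1 input_data payload; key names follow MODEL.py below,
# the table name and threshold value are example assumptions
input_data = {
    "data_source": "spark",       # one of: spark, elasticsearch, local_csv
    "table_name": "auth_events",  # required when data_source == "spark"
    "threshold": 50000,           # overrides the default of 10000
}
```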
model.yaml

```yaml
name: basic_model
version: 1.0.0
runtime: python-base
description: Basic model demonstrating Spark/ES adapters
parameters:
  threshold:
    type: integer
    default: 10000
    description: Row count threshold for anomaly detection
  data_source:
    type: string
    default: spark
    description: Data source type (spark, elasticsearch, local_csv)
    enum: [spark, elasticsearch, local_csv]
```
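
The manifest is plain YAML, so a quick way to sanity-check it is to load it and read the declared defaults; a minimal sketch assuming PyYAML is installed and model.yaml sits in the working directory:

```python
# load the manifest and print each parameter's declared default
import yaml

with open("model.yaml") as fh:
    manifest = yaml.safe_load(fh)

for name, spec in manifest["parameters"].items():
    print(f"{name}: default={spec['default']} ({spec['type']})")
# prints:
# threshold: default=10000 (integer)
# data_source: default=spark (string)
```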
MODEL.py

```python
'''
Copyright 2019-Present The OpenUBA Platform Authors
basic model template using data adapters
this model demonstrates how to use spark and elasticsearch data adapters
supports both v1 (execute) and v2 (train/infer) interfaces
'''

import os
import logging
from typing import Dict, Any, List, Optional
import pandas as pd

logger = logging.getLogger(__name__)


# v1 interface (backward compatibility)
def execute(input_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    '''
    v1 interface: execute function for backward compatibility
    '''
    logger.info("executing basic model (v1 interface)")
    return _execute_inference(input_data)


# v2 interface (preferred)
class Model:
    '''
    v2 model interface following model-lifecycle spec
    '''

    def train(self, ctx) -> Dict[str, Any]:
        '''
        train the model (optional - this is an inference-only model)
        '''
        ctx.logger.info("basic model does not require training")
        return {
            "status": "success",
            "message": "no training required for basic model"
        }

    def infer(self, ctx) -> pd.DataFrame:
        '''
        run inference and return anomalies as a dataframe
        ctx should have: df (dataframe), params (dict), logger
        '''
        ctx.logger.info("executing basic model inference (v2 interface)")

        # get data from context
        df = ctx.df if hasattr(ctx, 'df') else None
        params = ctx.params if hasattr(ctx, 'params') else {}

        if df is None or len(df) == 0:
            ctx.logger.warning("no data provided in context")
            return pd.DataFrame(columns=["entity_id", "entity_type", "risk_score", "anomaly_type", "timestamp", "details"])

        anomalies = []
        row_count = len(df)
        ctx.logger.info(f"processing {row_count} rows")

        # basic anomaly detection: flag if row count is suspiciously high
        threshold = params.get("threshold", 10000)
        if row_count > threshold:
            anomalies.append({
                "entity_id": "system",
                "entity_type": "data_volume",
                "risk_score": 0.7,
                "anomaly_type": "high_data_volume",
                "timestamp": pd.Timestamp.now(),
                "details": {
                    "row_count": row_count,
                    "threshold": threshold
                }
            })

        # convert to dataframe
        if anomalies:
            return pd.DataFrame(anomalies)
        else:
            return pd.DataFrame(columns=["entity_id", "entity_type", "risk_score", "anomaly_type", "timestamp", "details"])


def _execute_inference(input_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    '''
    internal inference logic shared by v1 and v2 interfaces
    '''
    logger.info("executing basic model")

    # get data source configuration from input
    if not input_data:
        input_data = {}

    data_source = input_data.get("data_source") or input_data.get("type", "spark")
    table_name = input_data.get("table_name")
    index_name = input_data.get("index_name")
    file_path = input_data.get("file_path")
    file_name = input_data.get("file_name")

    anomalies = []
    df = None

    try:
        if data_source == "spark":
            # use spark data adapter
            from core.model_modules.spark.spark import SparkDataLoader

            if not table_name:
                raise ValueError("table_name required for spark data source")

            loader = SparkDataLoader(table_name=table_name)
            df = loader.data

            # basic anomaly detection
            row_count = len(df) if hasattr(df, '__len__') else 0
            logger.info(f"loaded {row_count} rows from spark table: {table_name}")

            # example: flag if row count is suspiciously high
            threshold = input_data.get("threshold", 10000)
            if row_count > threshold:
                anomalies.append({
                    "entity_id": "system",
                    "entity_type": "data_volume",
                    "risk_score": 0.7,
                    "anomaly_type": "high_data_volume",
                    "details": {
                        "table_name": table_name,
                        "row_count": row_count,
                        "threshold": threshold
                    }
                })

        elif data_source == "elasticsearch":
            # use elasticsearch data adapter
            from core.model_modules.es.es import ESGeneric

            if not index_name:
                # try to infer from query or use default
                index_name = input_data.get("index", "openuba-*")

            # create query
            query = input_data.get("query", {"match_all": {}})
            es_host = input_data.get("host", os.getenv("ELASTICSEARCH_HOST", "http://elasticsearch:9200"))
            es_loader = ESGeneric(host=es_host, query=query)
            df = es_loader.data

            # basic anomaly detection
            if hasattr(df, 'shape') and len(df) > 0:
                row_count = len(df)
                logger.info(f"loaded {row_count} rows from elasticsearch")

                # example anomaly detection
                threshold = input_data.get("threshold", 10000)
                if row_count > threshold:
                    anomalies.append({
                        "entity_id": "system",
                        "entity_type": "data_volume",
                        "risk_score": 0.7,
                        "anomaly_type": "high_data_volume",
                        "details": {
                            "index_name": index_name,
                            "document_count": row_count,
                            "threshold": threshold
                        }
                    })

        elif data_source == "local_csv":
            # use local pandas csv adapter
            from core.model_modules.local_pandas.local_pandas import LocalPandasCSV

            if not file_path or not file_name:
                raise ValueError("file_path and file_name required for local_csv data source")

            loader = LocalPandasCSV(
                file_path=file_path,
                file=file_name,
                sep=input_data.get("sep", " "),
                header=input_data.get("header", 0),
                error_bad_lines=False,
                warn_bad_lines=False
            )
            df = loader.data

            if hasattr(df, 'shape') and len(df) > 0:
                row_count = len(df)
                logger.info(f"loaded {row_count} rows from local file: {file_name}")

        else:
            logger.warning(f"unknown data source: {data_source}")

    except Exception as e:
        logger.error(f"model execution failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        # return empty result on error
        return {
            "anomalies": [],
            "status": "error",
            "error": str(e)
        }

    logger.info(f"model execution completed, found {len(anomalies)} anomalies")

    return {
        "anomalies": anomalies,
        "status": "success",
        "anomaly_count": len(anomalies),
        "data_rows_processed": len(df) if df is not None else 0
    }
```
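
Both interfaces can be exercised outside the platform for a quick smoke test. The sketch below is an assumption-laden example: it treats MODEL.py as an importable module named MODEL, fabricates a small dataframe and a SimpleNamespace context for the v2 path, and uses made-up file names for the v1 local_csv path (which only loads data when the platform's local_pandas adapter is importable; otherwise execute returns its error payload).

```python
# illustrative smoke test; module name, dataframe, and file names are assumptions
import logging
from types import SimpleNamespace

import pandas as pd

import MODEL

logging.basicConfig(level=logging.INFO)

# v2 path: infer() expects a context carrying df, params, and logger
ctx = SimpleNamespace(
    df=pd.DataFrame({"user": ["a", "b"] * 6000}),  # 12,000 rows, above the default threshold
    params={"threshold": 10000},
    logger=logging.getLogger("basic_model_smoke_test"),
)
anomalies = MODEL.Model().infer(ctx)
print(anomalies[["entity_id", "risk_score", "anomaly_type"]])

# v1 path: dict in, dict out; returns the error payload if the
# core.model_modules.local_pandas adapter is not on the import path
result = MODEL.execute({
    "data_source": "local_csv",
    "file_path": "/tmp/openuba",  # example path
    "file_name": "events.csv",    # example file
    "sep": ",",
})
print(result["status"], len(result["anomalies"]))
```

With the values above, the v2 path returns a one-row dataframe flagging high_data_volume at risk_score 0.7, while the v1 path returns the summary dict with its anomalies list, status, and row counts.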