Back to Models

basic_model

v1.0.0
Python

Basic model demonstrating Spark, Elasticsearch, and local CSV data adapters. Supports both v1 (execute) and v2 (train/infer) interfaces for backward compatibility.

$ openuba install basic_model
OpenUBA
python-base
License: Apache-2.0
Tags: data-adapter, spark, elasticsearch, csv, multi-source, v1-compatible

Parameters

| Name | Type | Default | Description |
|---|---|---|---|
| threshold | integer | 10000 | Row count threshold for anomaly detection |
| data_source | string | spark | Data source type (spark, elasticsearch, local_csv) |
model.yaml
name: basic_model
version: 1.0.0
runtime: python-base
description: Basic model demonstrating Spark/ES adapters
parameters:
  threshold:
    type: integer
    default: 10000
    description: Row count threshold for anomaly detection
  data_source:
    type: string
    default: spark
    description: Data source type (spark, elasticsearch, local_csv)
    enum: [spark, elasticsearch, local_csv]
MODEL.py
'''
Copyright 2019-Present The OpenUBA Platform Authors
basic model template using data adapters
this model demonstrates how to use the spark, elasticsearch and local csv data adapters
supports both v1 (execute) and v2 (train/infer) interfaces
'''
7
8import os
9import logging
10from typing import Dict, Any, List, Optional
11import pandas as pd
12
# module-level logger shared by the v1 code path (execute / _execute_inference)
logger = logging.getLogger(name=__name__)
14
15
# v1 interface (backward compatibility)
def execute(input_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    '''
    v1 interface: execute function kept for backward compatibility

    input_data is the v1 configuration dict (data_source/type, table_name,
    index_name, file_path, file_name, threshold, ...); None is accepted and
    treated as an empty configuration.

    returns the result dict produced by _execute_inference ("anomalies",
    "status", and either counters or an "error" message)
    '''
    logger.info("executing basic model (v1 interface)")
    return _execute_inference(input_data)
23
24
# v2 interface (preferred)
class Model:
    '''
    v2 model interface following model-lifecycle spec

    inference-only model: train() is a no-op, and infer() flags a single
    "high_data_volume" anomaly when the input dataframe has more rows than
    params["threshold"] (default 10000).
    '''

    # column layout of the anomaly dataframe returned by infer()
    _ANOMALY_COLUMNS = (
        "entity_id",
        "entity_type",
        "risk_score",
        "anomaly_type",
        "timestamp",
        "details",
    )

    def train(self, ctx) -> Dict[str, Any]:
        '''
        train the model (no-op - this is an inference-only model)

        ctx must provide a logger. returns a status dict.
        '''
        ctx.logger.info("basic model does not require training")
        return {
            "status": "success",
            "message": "no training required for basic model"
        }

    def _empty_result(self) -> pd.DataFrame:
        '''
        return an empty anomaly dataframe with the expected columns
        '''
        return pd.DataFrame(columns=list(self._ANOMALY_COLUMNS))

    def infer(self, ctx) -> pd.DataFrame:
        '''
        run inference and return anomalies as a dataframe

        ctx should have: df (dataframe), params (dict), logger.
        df and params may be missing (or None); a missing/empty df yields an
        empty result and missing params fall back to the default threshold.
        params["threshold"] may be an int or anything int() accepts.

        returns a dataframe with the _ANOMALY_COLUMNS columns, empty when no
        anomaly is flagged.
        '''
        ctx.logger.info("executing basic model inference (v2 interface)")

        # get data from context
        df = getattr(ctx, "df", None)
        # params may be absent or explicitly None; both mean "no overrides"
        params = getattr(ctx, "params", None) or {}

        if df is None or len(df) == 0:
            ctx.logger.warning("no data provided in context")
            return self._empty_result()

        row_count = len(df)
        ctx.logger.info(f"processing {row_count} rows")

        # config values can arrive as strings; coerce so the comparison works
        raw_threshold = params.get("threshold")
        threshold = 10000 if raw_threshold is None else int(raw_threshold)

        # basic anomaly detection: flag if row count is suspiciously high
        if row_count <= threshold:
            return self._empty_result()

        anomaly = {
            "entity_id": "system",
            "entity_type": "data_volume",
            "risk_score": 0.7,
            "anomaly_type": "high_data_volume",
            "timestamp": pd.Timestamp.now(),
            "details": {
                "row_count": row_count,
                "threshold": threshold
            }
        }
        return pd.DataFrame([anomaly], columns=list(self._ANOMALY_COLUMNS))
80
81
def _row_count(df: Any) -> int:
    '''
    return len(df), or 0 when df is None or has no __len__

    adapters may hand back objects without __len__ (presumably e.g. a spark
    dataframe - TODO confirm against SparkDataLoader), so len() must never
    be called blindly on adapter output.
    '''
    if df is None or not hasattr(df, "__len__"):
        return 0
    return len(df)


def _parse_threshold(value: Any, default: int = 10000) -> int:
    '''
    coerce a threshold parameter to int; None means "use default"

    raises ValueError/TypeError when int() cannot convert the value.
    '''
    return default if value is None else int(value)


def _volume_anomalies(
    row_count: int,
    threshold: int,
    context: Dict[str, Any],
    count_key: str = "row_count",
) -> List[Dict[str, Any]]:
    '''
    return [high_data_volume anomaly] when row_count > threshold, else []

    the anomaly's "details" holds the context entries first, then
    count_key -> row_count, then "threshold".
    '''
    if row_count <= threshold:
        return []
    details = dict(context)
    details[count_key] = row_count
    details["threshold"] = threshold
    return [{
        "entity_id": "system",
        "entity_type": "data_volume",
        "risk_score": 0.7,
        "anomaly_type": "high_data_volume",
        "details": details,
    }]


def _execute_inference(input_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    '''
    internal inference logic used by the v1 execute() interface

    loads data through the adapter selected by input_data["data_source"]
    (falling back to input_data["type"], then "spark") and flags a
    high_data_volume anomaly when the loaded row count exceeds
    input_data["threshold"] (default 10000).

    returns a dict with "anomalies", "status" ("success"/"error"),
    "anomaly_count" and "data_rows_processed"; error results also carry
    "error" with the exception message. an unknown data source is logged
    as a warning and reported as a successful run with no anomalies.
    '''
    logger.info("executing basic model")

    # get data source configuration from input
    if not input_data:
        input_data = {}

    data_source = input_data.get("data_source") or input_data.get("type", "spark")
    table_name = input_data.get("table_name")
    index_name = input_data.get("index_name")
    file_path = input_data.get("file_path")
    file_name = input_data.get("file_name")

    anomalies: List[Dict[str, Any]] = []
    # tracked explicitly so the final result never calls len() on adapter output
    row_count = 0

    try:
        if data_source == "spark":
            # use spark data adapter
            from core.model_modules.spark.spark import SparkDataLoader

            if not table_name:
                raise ValueError("table_name required for spark data source")

            df = SparkDataLoader(table_name=table_name).data
            row_count = _row_count(df)
            logger.info(f"loaded {row_count} rows from spark table: {table_name}")

            anomalies.extend(_volume_anomalies(
                row_count,
                _parse_threshold(input_data.get("threshold")),
                {"table_name": table_name},
            ))

        elif data_source == "elasticsearch":
            # use elasticsearch data adapter
            from core.model_modules.es.es import ESGeneric

            if not index_name:
                # try to infer from query or use default
                index_name = input_data.get("index", "openuba-*")

            # NOTE(review): index_name is only reported in the anomaly details;
            # it is not passed to ESGeneric - confirm whether ESGeneric should
            # receive it.
            query = input_data.get("query", {"match_all": {}})
            es_host = input_data.get("host", os.getenv("ELASTICSEARCH_HOST", "http://elasticsearch:9200"))
            df = ESGeneric(host=es_host, query=query).data
            row_count = _row_count(df)

            if row_count > 0:
                logger.info(f"loaded {row_count} rows from elasticsearch")

            anomalies.extend(_volume_anomalies(
                row_count,
                _parse_threshold(input_data.get("threshold")),
                {"index_name": index_name},
                count_key="document_count",
            ))

        elif data_source == "local_csv":
            # use local pandas csv adapter
            from core.model_modules.local_pandas.local_pandas import LocalPandasCSV

            if not file_path or not file_name:
                raise ValueError("file_path and file_name required for local_csv data source")

            # NOTE(review): if LocalPandasCSV forwards these kwargs to
            # pandas.read_csv, error_bad_lines/warn_bad_lines were deprecated in
            # pandas 1.3 and removed in 2.0 (replacement: on_bad_lines) - confirm.
            loader = LocalPandasCSV(
                file_path=file_path,
                file=file_name,
                sep=input_data.get("sep", " "),
                header=input_data.get("header", 0),
                error_bad_lines=False,
                warn_bad_lines=False
            )
            row_count = _row_count(loader.data)

            if row_count > 0:
                logger.info(f"loaded {row_count} rows from local file: {file_name}")

            # apply the same documented threshold check as the other sources
            anomalies.extend(_volume_anomalies(
                row_count,
                _parse_threshold(input_data.get("threshold")),
                {"file_name": file_name},
            ))

        else:
            logger.warning(f"unknown data source: {data_source}")

    except Exception as e:
        # model boundary: report any adapter/config failure as an error result
        logger.exception(f"model execution failed: {e}")
        return {
            "anomalies": [],
            "status": "error",
            "error": str(e),
            "anomaly_count": 0,
            "data_rows_processed": row_count,
        }

    logger.info(f"model execution completed, found {len(anomalies)} anomalies")

    return {
        "anomalies": anomalies,
        "status": "success",
        "anomaly_count": len(anomalies),
        "data_rows_processed": row_count,
    }
208
209