001/**
002 * Copyright (C) 2007 - 2016, Jens Lehmann
003 *
004 * This file is part of DL-Learner.
005 *
006 * DL-Learner is free software; you can redistribute it and/or modify
007 * it under the terms of the GNU General Public License as published by
008 * the Free Software Foundation; either version 3 of the License, or
009 * (at your option) any later version.
010 *
011 * DL-Learner is distributed in the hope that it will be useful,
012 * but WITHOUT ANY WARRANTY; without even the implied warranty of
013 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
014 * GNU General Public License for more details.
015 *
016 * You should have received a copy of the GNU General Public License
017 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
018 */
019package org.dllearner.kb.sparql;
020
021import java.io.ByteArrayOutputStream;
022import java.io.StringReader;
023import java.io.UnsupportedEncodingException;
024import java.security.MessageDigest;
025import java.security.NoSuchAlgorithmException;
026import java.sql.Clob;
027import java.sql.Connection;
028import java.sql.DriverManager;
029import java.sql.PreparedStatement;
030import java.sql.ResultSet;
031import java.sql.SQLException;
032import java.sql.Statement;
033
034import org.dllearner.utilities.Helper;
035
036import org.apache.jena.query.ResultSetFactory;
037import org.apache.jena.query.ResultSetRewindable;
038import org.apache.jena.rdf.model.Model;
039import org.apache.jena.rdf.model.ModelFactory;
040import com.jamonapi.Monitor;
041import com.jamonapi.MonitorFactory;
042import org.apache.jena.sparql.engine.http.QueryEngineHTTP;
043
044/**
045 * The class is used to cache information about resources to a database.
046 * Provides the connection to an H2 database in a light weight, configuration free
047 * manner. 
048 * 
049 * Note: Currently, either select ot construct has to be used (not both).
050 * 
051 * @author Jens Lehmann
052 *
053 */
054public class ExtractionDBCache {
055
056        private String databaseDirectory = "cache";
057        private String databaseName = "extraction";
058        private boolean autoServerMode = true;
059        
060        // specifies after how many seconds a cached result becomes invalid
061        private long freshnessInMilliseconds = 15 * 24 * 60 * 60 * 1000; // 15 days     
062        
063        private int maxExecutionTimeInSeconds = 0;
064        
065        private Connection conn;
066        
067        MessageDigest md5;
068        
069        private Monitor mon = MonitorFactory.getTimeMonitor("Query");
070        
071        public ExtractionDBCache(Connection conn) throws SQLException
072        {
073                this.conn=conn;
074                try{md5 = MessageDigest.getInstance("MD5");}
075                catch(NoSuchAlgorithmException e) {throw new RuntimeException("Should never happen - MD5 not found.");}
076                // create cache table if it does not exist
077                Statement stmt = conn.createStatement();
078                stmt.execute("CREATE TABLE IF NOT EXISTS QUERY_CACHE(QUERYHASH BINARY PRIMARY KEY,QUERY VARCHAR(20000), TRIPLES CLOB, STORE_TIME TIMESTAMP)");          
079        }
080        
081        public String getCacheDirectory() {
082                return databaseDirectory;
083        }
084
085        public ExtractionDBCache(String cacheDir) {
086                databaseDirectory = cacheDir;
087                try {
088                        md5 = MessageDigest.getInstance("MD5");
089
090                        // load driver
091                        Class.forName("org.h2.Driver");
092
093                        String jdbcString = "";
094                        if (autoServerMode) {
095                                jdbcString = ";AUTO_SERVER=TRUE";
096                        }
097
098                        // connect to database (created automatically if not existing)
099                        conn = DriverManager.getConnection("jdbc:h2:" + databaseDirectory + "/" + databaseName + jdbcString, "sa", "pw");
100
101                        // create cache table if it does not exist
102                        try (Statement stmt = conn.createStatement()) {
103                                stmt.execute("CREATE TABLE IF NOT EXISTS QUERY_CACHE(QUERYHASH BINARY PRIMARY KEY,QUERY VARCHAR(20000), TRIPLES CLOB, STORE_TIME TIMESTAMP)");
104                        }
105                } catch (NoSuchAlgorithmException | ClassNotFoundException | SQLException e) {
106                        e.printStackTrace();
107                }
108        }
109        
110        public void setFreshnessInMilliseconds(long freshnessInMilliseconds) {
111                this.freshnessInMilliseconds = freshnessInMilliseconds;
112        }
113        
114        public Model executeConstructQuery(SparqlEndpoint endpoint, String query) throws SQLException, UnsupportedEncodingException {
115                return executeConstructQuery(endpoint, query, maxExecutionTimeInSeconds);
116        }
117        
118        public Model executeConstructQuery(SparqlEndpoint endpoint, String query, int maxExecutionTimeInSeconds) throws SQLException, UnsupportedEncodingException {
119                byte[] md5 = md5(query);                
120//              Timestamp currTS = new Timestamp(new java.util.Date().getTime());
121                PreparedStatement ps=conn.prepareStatement("SELECT * FROM QUERY_CACHE WHERE QUERYHASH=? LIMIT 1");
122                ps.setBytes(1, md5);
123                ResultSet rs = ps.executeQuery();
124                
125//              long startTime = System.nanoTime();
126                boolean entryExists = rs.next();
127                boolean readFromCache = entryExists && (System.currentTimeMillis() - rs.getTimestamp("STORE_TIME").getTime() < freshnessInMilliseconds);
128//              long runTime = System.nanoTime() - startTime;
129//              System.out.println(Helper.prettyPrintNanoSeconds(runTime, true, true));
130                
131                if(readFromCache) {
132//                      System.out.println("Reading from cache");
133//                      String posedQuery = rs.getString("QUERY");
134//                      System.out.println(posedQuery);
135                                
136                        Clob clob = rs.getClob("TRIPLES");
137//                      long startTime = System.nanoTime();
138                        Model readModel = ModelFactory.createDefaultModel();
139                        readModel.read(clob.getAsciiStream(), null, "N-TRIPLE");
140//                      long runTime = System.nanoTime() - startTime;
141//                      System.out.println(Helper.prettyPrintNanoSeconds(runTime, true, true));                 
142                        return readModel;
143                } else {
144                        mon.start();
145//                      System.out.println("Posing new query");
146                        
147//                      String endpoint = "http://139.18.2.37:8890/sparql";
148                        QueryEngineHTTP queryExecution = new QueryEngineHTTP(endpoint.getURL().toString(), query);
149                        queryExecution.setTimeout(maxExecutionTimeInSeconds * 1000);
150                        for (String dgu : endpoint.getDefaultGraphURIs()) {
151                                queryExecution.addDefaultGraph(dgu);
152                        }
153                        for (String ngu : endpoint.getNamedGraphURIs()) {
154                                queryExecution.addNamedGraph(ngu);
155                        }                       
156                        Model m2 = queryExecution.execConstruct();      
157                        
158                        // convert model to N-Triples
159                        ByteArrayOutputStream baos = new ByteArrayOutputStream();
160                        m2.write(baos, "N-TRIPLE");
161                        String modelStr = baos.toString("UTF-8");
162                        
163                        // use a prepared statement, so that Java handles all the escaping stuff correctly automatically
164                        PreparedStatement ps2;
165                        if(entryExists){
166                                ps2 = conn.prepareStatement("UPDATE QUERY_CACHE SET TRIPLES=?, STORE_TIME=? WHERE QUERYHASH=?");
167                                ps2.setClob(1, new StringReader(modelStr));
168                                ps2.setTimestamp(2, new java.sql.Timestamp(new java.util.Date().getTime()));
169                                ps2.setBytes(3, md5);
170                        } else {
171                                ps2 = conn.prepareStatement("INSERT INTO QUERY_CACHE VALUES(?,?,?,?)");
172                                ps2.setBytes(1, md5);
173                                ps2.setString(2, query);
174                                ps2.setClob(3, new StringReader(modelStr));
175                                ps2.setTimestamp(4, new java.sql.Timestamp(new java.util.Date().getTime()));
176                        }
177                        mon.stop();
178                        ps2.executeUpdate(); 
179                        
180                        return m2;
181                }
182        }
183        
184        public String executeSelectQuery(SparqlEndpoint endpoint, String query) {
185                return executeSelectQuery(endpoint, query, maxExecutionTimeInSeconds);
186        }
187        
188        public String executeSelectQuery(SparqlEndpoint endpoint, String query, int maxExecutionTimeInSeconds) {
189                
190                try {
191                        
192                        
193                byte[] md5 = md5(query);                
194                PreparedStatement ps=conn.prepareStatement("SELECT * FROM QUERY_CACHE WHERE QUERYHASH=? LIMIT 1");
195                ps.setBytes(1, md5);
196                ResultSet rs = ps.executeQuery();
197                
198                boolean entryExists = rs.next();
199                boolean readFromCache = entryExists && (System.currentTimeMillis() - rs.getTimestamp("STORE_TIME").getTime() < freshnessInMilliseconds);
200                
201                if(readFromCache) {
202//                      System.out.println("cache");
203                        Clob clob = rs.getClob("TRIPLES");
204                        return clob.getSubString(1, (int) clob.length());
205                } else {
206                        mon.start();
207//                      System.out.println("no-cache");
208                        QueryEngineHTTP queryExecution = new QueryEngineHTTP(endpoint.getURL().toString(), query);
209                        queryExecution.setTimeout(maxExecutionTimeInSeconds * 1000);
210                        for (String dgu : endpoint.getDefaultGraphURIs()) {
211                                queryExecution.addDefaultGraph(dgu);
212                        }
213                        for (String ngu : endpoint.getNamedGraphURIs()) {
214                                queryExecution.addNamedGraph(ngu);
215                        }
216                        org.apache.jena.query.ResultSet tmp = queryExecution.execSelect();
217                        ResultSetRewindable rs2 = ResultSetFactory.makeRewindable(tmp);
218                        String json = SparqlQuery.convertResultSetToJSON(rs2);
219                        
220                        // use a prepared statement, so that Java handles all the escaping stuff correctly automatically
221                        PreparedStatement ps2;
222                        if(entryExists){
223                                ps2 = conn.prepareStatement("UPDATE QUERY_CACHE SET TRIPLES=?, STORE_TIME=? WHERE QUERYHASH=?");
224                                ps2.setClob(1, new StringReader(json));
225                                ps2.setTimestamp(2, new java.sql.Timestamp(new java.util.Date().getTime()));
226                                ps2.setBytes(3, md5);
227                        } else {
228                                ps2 = conn.prepareStatement("INSERT INTO QUERY_CACHE VALUES(?,?,?,?)");
229                                ps2.setBytes(1, md5);
230                                ps2.setString(2, query);
231                                ps2.setClob(3, new StringReader(json));
232                                ps2.setTimestamp(4, new java.sql.Timestamp(new java.util.Date().getTime()));
233                        }
234                        mon.stop();
235                        ps2.executeUpdate(); 
236                        return json;
237                }
238                } catch(SQLException e) {
239                        e.printStackTrace();
240                        return null;
241                } finally{
242                        mon.stop();
243                }
244        }       
245        
246        public void closeConnection() throws SQLException {
247                conn.close();
248        }
249        
250        private synchronized byte[] md5(String string) {
251                md5.reset();
252                md5.update(string.getBytes());
253                return md5.digest();
254        }
255        
256        public static String toNTriple(Model m) {
257                ByteArrayOutputStream baos = new ByteArrayOutputStream();
258                m.write(baos, "N-TRIPLE");
259                try {
260                        return baos.toString("UTF-8");
261                } catch (UnsupportedEncodingException e) {
262                        e.printStackTrace();
263                        return null;
264                }
265        }
266        
267        public void setMaxExecutionTimeInSeconds(int maxExecutionTimeInSeconds){
268                this.maxExecutionTimeInSeconds = maxExecutionTimeInSeconds;
269        }
270        
271        public static void main(String[] args) throws ClassNotFoundException, SQLException, NoSuchAlgorithmException, UnsupportedEncodingException {
272                SparqlEndpoint endpoint = SparqlEndpoint.getEndpointDBpediaLiveAKSW();
273                String resource = "http://dbpedia.org/resource/Leipzig";
274                String query = "CONSTRUCT { <"+resource+"> ?p ?o } WHERE { <"+resource+"> ?p ?o }"; 
275                System.out.println("query: " + query);
276                
277                ExtractionDBCache h2 = new ExtractionDBCache("cache"); 
278                long startTime = System.nanoTime();
279                Model m = h2.executeConstructQuery(endpoint, query);
280//              for(int i=0; i<1000; i++) {
281//                      h2.executeConstructQuery(endpoint, query);
282//              }
283                long runTime = System.nanoTime() - startTime;
284                System.out.println("Answer obtained in " + Helper.prettyPrintNanoSeconds(runTime));
285                System.out.println(ExtractionDBCache.toNTriple(m));
286        }       
287
288}